File size: 5,575 Bytes
92e0882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional
import numpy as np
from spacy.tokens.token import Token
from spacy.tokens.span import Span

from lattice import Product as L

from heuristics import Heuristics

# A relation: the tokens linking two entities, paired with the entity they lead to.
Rel = Tuple[List[Token], "Entity"]
# A superlative: the list of tokens in which a superlative keyword was found.
Sup = List[Token]

# Shared heuristics instance used by `Entity.extract` when the caller passes none.
DEFAULT_HEURISTICS = Heuristics()


def find_superlatives(tokens, heuristics) -> List[Sup]:
    """Wrap `tokens` as a superlative if they mention a superlative keyword.

    Args:
        tokens: Candidate tokens; each must expose `.text` and `.i`.
        heuristics: Object whose `superlatives` attribute is an iterable of
            heuristics, each exposing a `keywords` container.

    Returns:
        `[tokens]` when the text of at least one token is a keyword of at
        least one superlative heuristic, otherwise an empty list.

    Note:
        On a match the *given* list is sorted in place by token index and that
        same list object is returned inside the result.
    """

    def mentions_keyword(heuristic) -> bool:
        # True as soon as one token's text is among this heuristic's keywords.
        for tok in tokens:
            if tok.text in heuristic.keywords:
                return True
        return False

    if not any(mentions_keyword(h) for h in heuristics.superlatives):
        return []

    tokens.sort(key=lambda tok: tok.i)
    return [tokens]

def expand_chunks(doc, chunks):
    """Widen every chunk to take in neighbouring tokens that depend on it.

    Args:
        doc: Indexable, sized sequence of tokens (a spaCy `Doc`); tokens must
            provide `is_ancestor`.
        chunks: Mapping from a key to a span with `start` / `end` offsets
            into `doc`.

    Returns:
        A new dict with the same keys, each mapped to a `Span` over `doc`
        covering the (possibly widened) range.

    A token outside the chunk is a candidate when some token of the original
    chunk reports `is_ancestor` for it. Leftward, every candidate that is not
    itself an ancestor of a token in another chunk moves the start to its
    position; the scan always runs to the beginning of `doc`. Rightward, such
    candidates move the end just past their position, and the scan stops at
    the first candidate that belongs to, or is an ancestor of a token in,
    another chunk.
    """

    def governed_by(span, idx):
        # Does any token of the original span dominate doc[idx]?
        return any(doc[j].is_ancestor(doc[idx]) for j in range(span.start, span.end))

    def touches_other_chunk(own_key, idx, count_membership):
        # Is doc[idx] an ancestor of a token in a different chunk or, when
        # `count_membership` is set, itself one of that chunk's tokens?
        for other_key in chunks:
            if own_key != other_key:
                other = chunks[other_key]
                for j in range(other.start, other.end):
                    if doc[idx].is_ancestor(doc[j]) or (count_membership and idx == j):
                        return True
        return False

    expanded = {}
    for key in chunks:
        span = chunks[key]
        left, right = span.start, span.end

        # Leftward scan: no early exit, so the furthest qualifying token wins.
        for idx in reversed(range(span.start)):
            if governed_by(span, idx) and not touches_other_chunk(key, idx, False):
                left = idx

        # Rightward scan: stop once a dependent token collides with another chunk.
        for idx in range(span.end, len(doc)):
            if not governed_by(span, idx):
                continue
            if touches_other_chunk(key, idx, True):
                break
            right = idx + 1

        expanded[key] = Span(doc=doc, start=left, end=right)
    return expanded

class Entity(NamedTuple):
    """Represents an entity with locative constraints extracted from the parse.

    Equality (`==` and `!=`) is defined over the head text, the relations and
    the superlatives, not over the raw tuple fields.
    """

    # Chunk (span) that contains the entity's head token.
    head: Span
    # (linking tokens, sub-entity) pairs discovered below the head.
    relations: List[Rel]
    # Token lists in which a superlative keyword was found.
    superlatives: List[Sup]

    @classmethod
    def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> Optional["Entity"]:
        """Extract entities from a spacy parse.

        Jointly recursive with `_get_rel_sups`.

        Args:
            head: Token to build the entity around.
            chunks: Mapping from token index to the chunk containing that token.
            heuristics: Keyword heuristics; `DEFAULT_HEURISTICS` when omitted.

        Returns:
            The extracted entity, or `None` when neither `head` nor its first
            child belongs to a chunk.
        """
        if heuristics is None:
            heuristics = DEFAULT_HEURISTICS

        if head.i not in chunks:
            # Handles predicative cases.
            children = list(head.children)
            if children and children[0].i in chunks:
                head = children[0]
                # TODO: Also extract predicative relations.
            else:
                return None
        hchunk = chunks[head.i]
        rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
        return cls(hchunk, rels, sups)

    @classmethod
    def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
        """Walk the subtree under `token`, collecting relations and superlatives.

        Args:
            token: Token currently being visited.
            head: Head token of the entity being built.
            tokens: Linking tokens accumulated on the way down from `head`.
            chunks: Mapping from token index to the chunk containing that token.
            heuristics: Provides `relations`, `superlatives` and `null_keywords`.

        Returns:
            A `(relations, superlatives)` pair for the visited subtree.
        """
        hchunk = chunks[head.i]
        is_keyword = any(token.text in h.keywords for h in heuristics.relations)
        is_keyword |= token.text in heuristics.null_keywords

        # Found another entity head.
        if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
            # The accumulated path is sorted in place before being stored.
            tokens.sort(key=lambda tok: tok.i)
            subhead = cls.extract(token, chunks, heuristics)
            return [(tokens, subhead)], []

        # End of a chain of modifiers.
        n_children = len(list(token.children))
        if n_children == 0:
            return [], find_superlatives(tokens + [token], heuristics)

        relations = []
        superlatives = []
        is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives)
        for child in token.children:
            if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]:
                # Children inside the same chunk are skipped unless they are
                # themselves superlative keywords.
                if not any(child.text in h.keywords for h in heuristics.superlatives):
                    if n_children == 1:
                        # Catches "the goat on the left"
                        sups = find_superlatives(tokens + [token], heuristics)
                        superlatives.extend(sups)
                    continue
            # Only tokens outside every chunk, or keyword tokens, join the path.
            new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
            subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
            relations.extend(subrel)
            superlatives.extend(subsup)
        return relations, superlatives

    def expand(self, span: Optional[Span] = None) -> Text:
        """Render the entity, widened along the dependency tree, as text.

        Args:
            span: Optional iterable of target tokens. When omitted, the head
                tokens and all of their descendants are included. When given,
                the head tokens are included and, for each target found below
                the head, so are the target, its descendants and all of its
                ancestors up to the token that is its own head.

        Returns:
            The texts of the selected tokens, de-duplicated, ordered by token
            index and joined with single spaces.
        """
        tokens = list(self.head)
        if span is None:
            span = [None]
        for target_token in span:
            include = False
            stack = list(self.head)
            while stack:
                token = stack.pop()
                if token == target_token:
                    # Collect every ancestor of the target up to the root.
                    token2 = target_token.head
                    while token2.head != token2:
                        tokens.append(token2)
                        token2 = token2.head
                    tokens.append(token2)
                    # Drop pending siblings; only the target's subtree remains.
                    stack = []
                    include = True
                if target_token is None or include:
                    tokens.append(token)
                for child in token.children:
                    stack.append(child)
        ordered = sorted(set(tokens), key=lambda x: x.i)
        return ' '.join([token.text for token in ordered])

    def __eq__(self, other: object) -> bool:
        """Compare by head text, relations and superlatives.

        Returns `NotImplemented` for non-`Entity` operands so that comparing
        with e.g. `None` yields `False` instead of raising `AttributeError`.
        """
        if not isinstance(other, Entity):
            return NotImplemented
        return (
            self.text == other.text
            and self.relations == other.relations
            and self.superlatives == other.superlatives
        )

    def __ne__(self, other: object) -> bool:
        """Negate `__eq__`.

        Without this override `!=` would fall back to `tuple.__ne__`, which
        compares the raw fields and can disagree with `__eq__`.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    @property
    def text(self) -> Text:
        """Get the text predicate associated with this entity."""
        return self.head.text