|
"""Use spatial relations extracted from the parses.""" |
|
|
|
from typing import Dict, Any, Callable, List, Tuple, NamedTuple |
|
from numbers import Number |
|
from collections import defaultdict |
|
from overrides import overrides |
|
import numpy as np |
|
import spacy |
|
from spacy.tokens.token import Token |
|
from spacy.tokens.span import Span |
|
from argparse import Namespace |
|
|
|
from .ref_method import RefMethod |
|
from lattice import Product as L |
|
from heuristics import Heuristics |
|
from entity_extraction import Entity, expand_chunks |
|
|
|
|
|
def get_conjunct(ent, chunks, heuristics: Heuristics) -> Entity: |
|
"""If an entity represents a conjunction of two entities, pull them apart.""" |
|
head = ent.head.root |
|
if not any(child.text == "and" for child in head.children): |
|
return None |
|
for child in head.children: |
|
if child.i in chunks and head.i is not child.i: |
|
return Entity.extract(child, chunks, heuristics) |
|
return None |
|
|
|
|
|
class Parse(RefMethod):
    """A REF method that extracts and composes predicates, relations, and superlatives from a dependency parse.

    The process is as follows:
        1. Use spacy to parse the document.
        2. Extract a semantic entity tree from the parse.
        3. Execute the entity tree to yield a distribution over boxes.
    """

    # Shared spaCy pipeline: loaded once at class-creation time, reused by all instances.
    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace = None):
        # NOTE(review): passing args=None would crash on the attribute reads
        # below; callers are expected to always supply a populated Namespace.
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.baseline_threshold = args.baseline_threshold
        self.temperature = args.temperature
        self.superlative_head_only = args.superlative_head_only
        self.expand_chunks = args.expand_chunks
        self.branch = not args.parse_no_branch
        self.possessive_expand = not args.possessive_no_expand

        # Keyword-triggered callbacks for relations/superlatives.
        self.heuristics = Heuristics(args)

        # Running tallies reported by `get_stats`.
        self.counts = defaultdict(int)

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes."""
        # Baseline distribution over boxes for the full caption.
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        ori_probs = probs

        # Parse the caption and build the semantic entity tree.
        doc = self.nlp(caption)
        head = self.get_head(doc)
        chunks = self.get_chunks(doc)
        if self.expand_chunks:
            chunks = expand_chunks(doc, chunks)
        entity = Entity.extract(head, chunks, self.heuristics)

        # Fallback: anchor the tree at the first noun chunk's head.
        # (Materialize noun_chunks once instead of twice.)
        noun_chunks = list(doc.noun_chunks)
        if entity is None and noun_chunks:
            head = noun_chunks[0]
            entity = Entity.extract(head.root.head, chunks, self.heuristics)
            self.counts["n_0th_noun"] += 1

        # Recurse through the tree only when the caption contains a known
        # relation/superlative keyword (or branching is disabled).
        if entity is not None and (any(any(token.text in h.keywords for h in self.heuristics.relations+self.heuristics.superlatives) for token in doc) or not self.branch):
            ent_probs, texts = self.execute_entity(entity, env, chunks)
            probs = L.meet(probs, ent_probs)
        else:
            texts = [caption]
            self.counts["n_full_expr"] += 1

        # With a single candidate box the baseline is already decisive.
        if len(ori_probs) == 1:
            probs = ori_probs

        self.counts["n_total"] += 1
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
            "texts": texts
        }

    def execute_entity(self,
                       ent: Entity,
                       env: "Environment",
                       chunks: Dict[int, Span],
                       root: bool = True,
                       ) -> Any:
        """Execute an `Entity` tree recursively, yielding a distribution over boxes.

        Returns:
            For recursive calls (``root=False``): the probability vector alone.
            At the root (``root=True``): a tuple
            ``(probs, (texts, return_probs, chunk_strings))`` — the original
            ``-> np.ndarray`` annotation was inaccurate for this case.
        """
        self.counts["n_rec"] += 1
        probs = [1, 1]
        head_probs = probs

        # Degenerate case: nothing to disambiguate.
        # Fix: recursive callers feed the result to np.expand_dims and
        # therefore need a bare vector, not a (probs, texts) tuple.
        if len(probs) == 1 or len(env.boxes) == 1:
            return (probs, [ent.text]) if root else probs

        m1, m2 = probs[:2]
        text = ent.text
        rel_probs = []
        # Only chase relations when the top-2 margin is below the baseline threshold.
        if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2:
            self.counts["n_rec_rel"] += 1
            for tokens, ent2 in ent.relations:
                self.counts["n_rel"] += 1
                rel = None

                # Binary relations (e.g. "left of"): first matching keyword wins.
                for heuristic in self.heuristics.relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break

                if rel is not None:
                    # Combine the related entity's distribution with the
                    # pairwise relation matrix, then project back onto boxes.
                    probs2 = self.execute_entity(ent2, env, chunks, root=False)
                    events = L.meet(np.expand_dims(probs2, axis=0), rel)
                    new_probs = L.join_reduce(events)
                    rel_probs.append((ent2.text, new_probs, probs2))
                    continue

                # Ternary relations (e.g. "between X and Y") need a conjunct.
                rel = None
                for heuristic in self.heuristics.ternary_relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
                if rel is not None:
                    ent3 = get_conjunct(ent2, chunks, self.heuristics)
                    if ent3 is not None:
                        probs2 = self.execute_entity(ent2, env, chunks, root=False)
                        probs2 = np.expand_dims(probs2, axis=[0, 2])
                        probs3 = self.execute_entity(ent3, env, chunks, root=False)
                        probs3 = np.expand_dims(probs3, axis=[0, 1])
                        events = L.meet(L.meet(probs2, probs3), rel)
                        new_probs = L.join_reduce(L.join_reduce(events))
                        probs = L.meet(probs, new_probs)
                    continue

                # No relation matched: treat the modifier as a possessive and
                # fold its surface text back into the query string.
                if not self.args.no_possessive:
                    if self.possessive_expand:
                        text = ent.expand(ent2.head)
                    else:
                        text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}'

        probs = self._filter(text, env, root=root)
        texts = [text]
        return_probs = [(probs.tolist(), probs.tolist())]
        for (ent2_text, new_probs, ent2_only_probs) in rel_probs:
            probs = L.meet(probs, new_probs)
            probs /= probs.sum()
            texts.append(ent2_text)
            return_probs.append((probs.tolist(), ent2_only_probs.tolist()))

        # Apply superlatives only while the two best boxes remain close.
        m1, m2 = probs[(-probs).argsort()[:2]]
        if m1 < self.baseline_threshold * m2:
            self.counts["n_rec_sup"] += 1
            for tokens in ent.superlatives:
                self.counts["n_sup"] += 1
                sup = None
                for heuristic in self.heuristics.superlatives:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        texts.append('sup:'+' '.join([tok.text for tok in tokens if tok.text in heuristic.keywords]))
                        sup = heuristic.callback(env)
                        self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1
                        break
                if sup is not None:
                    # Optionally condition only on the head's own distribution.
                    precond = head_probs if self.superlative_head_only else probs
                    probs = L.meet(np.expand_dims(precond, axis=1)*np.expand_dims(precond, axis=0), sup).sum(axis=1)
                    probs = probs / probs.sum()
                    return_probs.append((probs.tolist(), None))

        if root:
            assert len(texts) == len(return_probs)
            return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values()))
        return probs

    def get_head(self, doc) -> Token:
        """Return the token that is the head of the dependency parse, or None."""
        for token in doc:
            # The root token is its own head in a spaCy parse.
            if token.head.i == token.i:
                return token
        return None

    def get_chunks(self, doc) -> Dict[int, Any]:
        """Return a dictionary mapping token indices to their noun chunk."""
        chunks = {}
        for chunk in doc.noun_chunks:
            for idx in range(chunk.start, chunk.end):
                chunks[idx] = chunk
        return chunks

    @overrides
    def get_stats(self) -> Dict[str, Number]:
        """Summary statistics that have been tracked on this object."""
        stats = dict(self.counts)
        n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_"))
        n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_"))
        stats.update({
            "p_rel_caught": n_rel_caught / (self.counts["n_rel"] + 1e-9),
            "p_sup_caught": n_sup_caught / (self.counts["n_sup"] + 1e-9),
            "p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + 1e-9),
            "p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + 1e-9),
            "p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + 1e-9),
            "p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + 1e-9),
            # Fix: guard against division by zero, consistent with the entries above.
            "avg_rec": self.counts["n_rec"] / (self.counts["n_total"] + 1e-9),
        })
        return stats

    def _filter(self,
                caption: str,
                env: "Environment",
                root: bool = False,
                expand: float = None,
                ) -> np.ndarray:
        """Wrap a filter call in a consistent way for all recursions.

        The area threshold is applied only at the root query; `expand` is
        currently unused but kept for interface compatibility.
        """
        kwargs = {
            "softmax": not self.args.sigmoid,
            "temperature": self.args.temperature,
        }
        if root:
            return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs)
        else:
            return env.filter(caption, **kwargs)
|
|