Spaces:
Sleeping
Sleeping
File size: 2,272 Bytes
98e2ea5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from coref_utils.utils import get_mention_to_cluster_idx
from collections import defaultdict
def get_gt_actions(pred_mentions, document, mem_type_config, mapped_mentions=[]):
    """Return one action tuple per predicted mention for ``document``.

    Args:
        pred_mentions: Sized collection of predicted mentions.
        document: Mapping describing the document. When it contains a
            ``"clusters"`` key, that value is used as the ground truth.
        mem_type_config: Accepted for interface compatibility; not read here.
        mapped_mentions: Passed through unchanged to
            ``get_actions_unbounded_fast``.

    Returns:
        The result of ``get_actions_unbounded_fast`` when ground-truth
        clusters are present; otherwise a list holding ``(-1, "i")`` once
        per predicted mention.
    """
    if "clusters" not in document:
        # No annotations (e.g. inference on raw text): emit placeholders so
        # the output still lines up one-to-one with the predicted mentions.
        placeholder = (-1, "i")
        return [placeholder] * len(pred_mentions)

    # Ground truth is available: derive the actions from the gold clusters.
    return get_actions_unbounded_fast(
        pred_mentions, document["clusters"], mapped_mentions
    )
def action_sequences_to_clusters(actions, mentions, num_major_entities):
    """Group mentions into clusters according to their paired actions.

    Args:
        actions: Iterable of ``(cell_idx, action_type)`` pairs, consumed in
            lockstep with ``mentions``.
        mentions: Iterable of mentions, one per action.
        num_major_entities: Number of regular cluster slots. One extra slot
            at index ``num_major_entities`` collects every ``"o"`` mention.

    Returns:
        A list of ``num_major_entities + 1`` lists. Slot ``k`` holds the
        mentions whose action pointed at cell ``k``; the final slot holds the
        ``"o"`` mentions. Mentions with action type ``"i"`` are dropped.
    """
    overflow_slot = num_major_entities

    # First pass: bucket mentions by destination slot, in encounter order.
    grouped = {}
    for span, (slot, kind) in zip(mentions, actions):
        if kind == "i":
            continue
        destination = overflow_slot if kind == "o" else slot
        grouped.setdefault(destination, []).append(span)

    # Second pass: place each bucket into a dense, fixed-size list so that
    # slots which received nothing are still present as empty lists.
    result = [[] for _ in range(num_major_entities + 1)]
    for slot, members in grouped.items():
        result[slot] = members
    return result
def get_cluster_to_cell(mapped_mentions, mention_to_cluster):
    """Assign consecutive cell indices to clusters in first-seen order.

    Args:
        mapped_mentions: Iterable of mentions; each is converted to a tuple
            before being looked up.
        mention_to_cluster: Mapping from mention tuple to its cluster.

    Returns:
        Dict mapping each cluster reached through ``mapped_mentions`` to a
        cell index, numbered 0, 1, 2, ... in the order the clusters are first
        encountered. Mentions missing from ``mention_to_cluster`` are
        reported on stdout and otherwise ignored.
    """
    assignment = {}
    for raw_mention in mapped_mentions:
        key = tuple(raw_mention)
        if key not in mention_to_cluster:
            print("Error: Mention not in mentions", key)
            continue
        # The next free cell index always equals the number of clusters
        # already assigned, so no separate counter is needed.
        assignment.setdefault(mention_to_cluster[key], len(assignment))
    return assignment
def get_actions_unbounded_fast(pred_mentions, gt_clusters, mapped_mentions=[]):
    """Derive one ``(cluster_idx, action_type)`` pair per predicted mention.

    The last entry of ``gt_clusters`` is treated specially: mentions that
    belong to it, and mentions found in no cluster at all, both receive the
    ``"o"`` action. Every other mention receives ``"c"`` together with the
    index of its cluster.

    Args:
        pred_mentions: Iterable of predicted mentions; each is converted to a
            tuple before lookup.
        gt_clusters: Sized collection of ground-truth clusters, handed to
            ``get_mention_to_cluster_idx`` (presumably yielding a mapping
            from mention tuple to cluster index; confirm in
            ``coref_utils.utils``).
        mapped_mentions: Accepted for interface compatibility; not read here.

    Returns:
        List of ``(cluster_idx, action_type)`` tuples aligned with
        ``pred_mentions``.
    """
    last_idx = len(gt_clusters) - 1
    cluster_of = get_mention_to_cluster_idx(gt_clusters)

    def _action_for(span):
        key = tuple(span)
        if key in cluster_of:
            found = cluster_of[key]
            if found != last_idx:
                return (found, "c")
            return (found, "o")
        # Unknown mention: route it to the last cluster with "o".
        return (last_idx, "o")

    return [_action_for(span) for span in pred_mentions]
|