File size: 1,529 Bytes
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import List, Dict, Tuple


def filter_clusters(clusters: List, threshold: int = 1) -> List:
    """Keep clusters with at least ``threshold`` mentions, always dropping the last cluster.

    Args:
        clusters: Sized sequence of clusters. Each cluster must support ``len``
            and iterate over mentions; each mention must itself be iterable
            (presumably a span pair — confirm against callers).
        threshold: Minimum number of mentions a cluster needs in order to be kept.

    Returns:
        A list of the surviving clusters, in their original order, with every
        cluster converted to a tuple of mention tuples. The final element of
        ``clusters`` is never included, regardless of its size.
    """
    final_position = len(clusters) - 1
    kept = []
    for position, cluster in enumerate(clusters):
        large_enough = len(cluster) >= threshold
        # The last cluster is excluded unconditionally (deliberate in the original
        # design; presumably a placeholder cluster — NOTE(review): confirm).
        if large_enough and position != final_position:
            kept.append(tuple(tuple(mention) for mention in cluster))
    return kept


def get_mention_to_cluster(clusters: List) -> Dict:
    """Build a lookup from each mention to the cluster that contains it.

    Args:
        clusters: Iterable of clusters, each an iterable of mentions, where
            every mention is itself iterable.

    Returns:
        A dict whose keys are mentions converted to tuples and whose values are
        the owning cluster converted to a tuple of mention tuples. If a mention
        occurs in more than one cluster, the cluster seen last wins.
    """
    normalized_clusters = [
        tuple(tuple(mention) for mention in raw_cluster) for raw_cluster in clusters
    ]
    # Dict comprehension inserts in iteration order, so later clusters overwrite earlier ones.
    return {
        mention: cluster
        for cluster in normalized_clusters
        for mention in cluster
    }


def get_mention_to_cluster_idx(clusters: List) -> Dict:
    """Build a lookup from each mention to the index of the cluster containing it.

    No clusters are filtered out; every mention of every cluster is indexed.

    Args:
        clusters: Iterable of clusters, each an iterable of mentions, where
            every mention is itself iterable.

    Returns:
        A dict mapping each mention (converted to a tuple) to the 0-based
        position of its cluster in ``clusters``. If a mention occurs in more
        than one cluster, the highest index wins.
    """
    return {
        tuple(mention): position
        for position, cluster in enumerate(clusters)
        for mention in cluster
    }


def is_aligned(span1: Tuple[int, int], span2: Tuple[int, int]) -> bool:
    """Report whether one span lies entirely inside the other.

    Only the first two elements of each span (start, end) are read, and the
    boundaries are inclusive, so identical spans count as aligned.

    Returns:
        ``True`` if ``span1`` is contained in ``span2`` or vice versa,
        otherwise ``False``.
    """
    start1, end1 = span1[0], span1[1]
    start2, end2 = span2[0], span2[1]
    return bool(
        (start1 >= start2 and end1 <= end2)  # span1 inside span2
        or (start2 >= start1 and end2 <= end1)  # span2 inside span1
    )