File size: 2,210 Bytes
8b513d0
d6504ae
8b513d0
f51d9c9
 
8b513d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24a93ab
8b513d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f51d9c9
 
 
 
 
 
 
 
 
 
 
8b513d0
 
 
 
24a93ab
8b513d0
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from collections import Counter, defaultdict
from .DatabaseConnection import get_wikidata_instance

# Preview cutoff used by EntityCollection.__repr__ when listing entities.
MAX_ITEMS_PREVIEW = 20


class EntityCollection:
    """A list-like container of entity objects with category helpers.

    Entities are duck-typed: the methods below call ``get_categories``,
    ``get_preview_string`` and ``pretty_print`` on them and format them
    with ``str()``; nothing else about them is assumed here.
    """

    def __init__(self, entities=None):
        """Wrap *entities*, or start from a fresh empty list when omitted.

        A sequence passed in is stored by reference (not copied), so
        ``append`` on this collection also mutates the caller's list.
        """
        # None sentinel instead of a ``[]`` default: a literal default list
        # would be shared by every default-constructed collection.
        self.entities = [] if entities is None else entities

    def __iter__(self):
        """Yield the wrapped entities in order."""
        yield from self.entities

    def __getitem__(self, item):
        """Return ``self.entities[item]`` (whatever indexing the sequence supports)."""
        return self.entities[item]

    def __len__(self) -> int:
        """Return the number of wrapped entities."""
        return len(self.entities)

    def append(self, entity) -> None:
        """Append *entity* to the underlying sequence in place."""
        self.entities.append(entity)

    def get_categories(self, max_depth=1):
        """Return all entities' categories concatenated into one list.

        Duplicates are kept: a category appears once per occurrence across
        all entities. *max_depth* is forwarded unchanged to each entity's
        ``get_categories``.
        """
        return [
            category
            for entity in self.entities
            for category in entity.get_categories(max_depth)
        ]

    def print_super_entities(self, max_depth=1, limit=10) -> None:
        """Print the *limit* most frequent categories and their entities.

        Each printed line has the form ``"<name> (<count>) : <e1>,<e2>,..."``
        where ``<name>`` comes from the Wikidata instance's
        ``get_entity_name`` and entities are rendered with ``str()``.
        ``limit`` is passed to ``Counter.most_common``, so ``None`` prints
        every category.
        """
        wikidata = get_wikidata_instance()

        counts = Counter()
        entities_by_category = defaultdict(list)
        for entity in self.entities:
            for category in entity.get_categories(max_depth):
                entities_by_category[category].append(entity)
                # Count in encounter order so most_common() breaks ties the
                # same way as counting a flat list of categories would.
                counts[category] += 1

        for category, frequency in counts.most_common(limit):
            name = wikidata.get_entity_name(category)
            members = ",".join(str(e) for e in entities_by_category[category])
            print("{} ({}) : {}".format(name, frequency, members))

    def __repr__(self) -> str:
        """Return a multi-line preview listing up to MAX_ITEMS_PREVIEW entities."""
        parts = ["<EntityCollection ({} entities):".format(len(self))]
        for index, entity_element in enumerate(self):
            # ``>=`` so exactly MAX_ITEMS_PREVIEW entities are shown and the
            # "more" count matches the number actually omitted.
            if index >= MAX_ITEMS_PREVIEW:
                parts.append("\n...{} more".format(len(self) - MAX_ITEMS_PREVIEW))
                break
            parts.append("\n-{}".format(entity_element.get_preview_string()))

        parts.append(">")
        return "".join(parts)

    def pretty_print(self) -> None:
        """Call ``pretty_print()`` on each entity in order."""
        for entity in self.entities:
            entity.pretty_print()

    def grouped_by_super_entities(self, max_depth=1):
        """Return a ``Counter`` mapping each category to its occurrence count."""
        return Counter(self.get_categories(max_depth))

    def get_distinct_categories(self, max_depth=1):
        """Return the unique categories as a list (order is arbitrary)."""
        return list(set(self.get_categories(max_depth)))