# File size: 1,699 Bytes
# db24a4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from candidate import Candidate
from frame import Frame
from helper import MULTI_OBJECT_BONUS, get_hypernym_path

from nltk.corpus import wordnet as wn


class Searcher:
    """Score and rank frames found in a trie for a multi-object query."""

    def __init__(self, trie):
        """Store the trie that ``search`` will look frames up in.

        Args:
            trie: Object exposing ``search(hypernym_path)``. Each item it
                returns is read for ``frame``, ``p_total`` and
                ``p_of(amount)``.
        """
        self.trie = trie

    @staticmethod
    def _parse_amount(amount):
        """Normalise a query amount: ``None`` for 'any', otherwise an ``int``.

        Accepts real integers, integral floats and integer strings such as
        ``"2"`` (the query is typed as ``dict[str, str]``, so amounts may
        arrive as text).

        Raises:
            ValueError: If ``amount`` is neither ``'any'`` nor an integer.
        """
        if amount == 'any':
            return None
        try:
            count = int(amount)
        except (TypeError, ValueError, OverflowError):
            raise ValueError('Amount must be an integer or "any"') from None
        # int() truncates non-integral numbers (2.5 -> 2); reject those rather
        # than silently searching for a different amount.
        if not isinstance(amount, str) and count != amount:
            raise ValueError('Amount must be an integer or "any"')
        return count

    def search(self, query: list[dict[str, str]], topk: int) -> list[Candidate]:
        """Return the ``topk`` best-scoring frames for ``query``.

        For every query item the object's hypernym path is looked up in the
        trie. Each returned node frame is scored with ``p_total`` when the
        amount is ``'any'``, otherwise with ``p_of(amount) * amount``. Scores
        are accumulated per frame id; every time a frame id that is already
        present is hit again, ``MULTI_OBJECT_BONUS`` is added on top of the
        new score.

        Args:
            query: Items with an ``'object'`` key and an ``'amount'`` key;
                the amount is ``'any'`` or an integer (number or numeric
                string).
            topk: Maximum number of candidates to return.

        Returns:
            Candidates sorted by descending score, at most ``topk`` of them.
            Each one wraps a new ``Frame`` built from the frame id only.

        Raises:
            ValueError: If an amount is neither ``'any'`` nor an integer.
        """
        scores: dict[str, float] = {}
        for item in query:
            target, raw_amount = item['object'], item['amount']
            # Validate before touching the trie so bad input fails fast.
            count = self._parse_amount(raw_amount)
            hypernym_path = get_hypernym_path(target)
            node_frames = self.trie.search(hypernym_path)

            for node_frame in node_frames:
                if count is None:
                    score = node_frame.p_total
                else:
                    score = node_frame.p_of(count) * count
                candidate = Candidate(node_frame.frame, score)
                frame_id = candidate.frame.id
                if frame_id in scores:
                    # Frame already matched earlier: reward the repeated hit.
                    scores[frame_id] += candidate.score + MULTI_OBJECT_BONUS
                else:
                    scores[frame_id] = candidate.score

        ranked = sorted(
            (Candidate(Frame(id=frame_id), score) for frame_id, score in scores.items()),
            key=lambda candidate: candidate.score,
            reverse=True,
        )
        return ranked[:topk]