Spaces:
Runtime error
Runtime error
File size: 1,699 Bytes
db24a4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from candidate import Candidate
from frame import Frame
from helper import MULTI_OBJECT_BONUS, get_hypernym_path
from nltk.corpus import wordnet as wn
class Searcher:
    """Rank frames against a multi-object query using a hypernym trie.

    Each query item names an object and a required amount. The object's
    hypernym path (from ``get_hypernym_path``) is looked up in the trie, and
    every node frame returned contributes a score to its frame's id. When a
    frame id is scored again, ``MULTI_OBJECT_BONUS`` is added on top of the
    repeated score.
    """

    def __init__(self, trie):
        """Store the trie used to look up frames.

        Args:
            trie: Object exposing ``search(hypernym_path)``; its results are
                iterated as node frames with ``frame`` (having an ``id``),
                ``p_total`` and ``p_of(amount)``.
        """
        self.trie = trie

    @staticmethod
    def _parse_amount(amount):
        """Normalise a query amount: ``None`` for ``'any'``, otherwise an ``int``.

        Accepts ints, integral numbers (e.g. ``3.0``) and integer strings
        (e.g. ``'3'``).

        Raises:
            ValueError: If ``amount`` is neither ``'any'`` nor an integer.
        """
        if amount == 'any':
            return None
        try:
            count = int(amount)
        except (TypeError, ValueError):
            raise ValueError('Amount must be an integer or "any"') from None
        # int() truncates non-integral numbers (2.5 -> 2); reject those.
        # Strings that int() accepted are already integral.
        if not isinstance(amount, str) and count != amount:
            raise ValueError('Amount must be an integer or "any"')
        return count

    def search(self, query: list[dict[str, str]], topk: int) -> list[Candidate]:
        """Return the ``topk`` highest-scoring frames for ``query``.

        Args:
            query: Items with an ``'object'`` key (passed to
                ``get_hypernym_path``) and an ``'amount'`` key that is either
                ``'any'`` or an integer (an int, or a string such as ``'3'``).
            topk: Maximum number of candidates to return; values <= 0 yield
                an empty list.

        Returns:
            Candidates sorted by descending score. Each wraps a new
            ``Frame(id=...)`` built from the matched frame's id.

        Raises:
            ValueError: If an amount is neither ``'any'`` nor an integer.
        """
        scores: dict[str, float] = {}
        for item in query:
            target, amount = item['object'], item['amount']
            # Validate before touching the trie so bad input fails fast.
            count = self._parse_amount(amount)
            node_frames = self.trie.search(get_hypernym_path(target))
            if count is None:
                matches = [
                    Candidate(node_frame.frame, node_frame.p_total)
                    for node_frame in node_frames
                ]
            else:
                matches = [
                    Candidate(node_frame.frame, node_frame.p_of(count) * count)
                    for node_frame in node_frames
                ]
            for match in matches:
                frame_id = match.frame.id
                if frame_id in scores:
                    # NOTE(review): the bonus is applied on every repeat of a
                    # frame id, including repeats within a single query item if
                    # trie.search can return the same frame more than once --
                    # confirm that is intended for a "multi-object" bonus.
                    scores[frame_id] += match.score + MULTI_OBJECT_BONUS
                else:
                    scores[frame_id] = match.score
        ranked = sorted(
            (Candidate(Frame(id=frame_id), score) for frame_id, score in scores.items()),
            key=lambda candidate: candidate.score,
            reverse=True,
        )
        # Clamp so a negative topk cannot silently drop trailing results.
        return ranked[:max(topk, 0)]
|