speed up
src/utils/paper_retriever.py (changed, +13 -13)
@@ -188,15 +188,11 @@ class Retriever(object):
         return similarity
 
     def cal_related_score(
-        self,
+        self, embedding, related_paper_id_list, type_name="embedding"
     ):
         score_1 = np.zeros((len(related_paper_id_list)))
         score_2 = np.zeros((len(related_paper_id_list)))
-
-        entities = self.api_helper.generate_entity_list(context)
-        origin_vector = self.embedding_model.encode(
-            context, convert_to_tensor=True, device=self.device
-        ).unsqueeze(0)
+        origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
         context_embeddings = [
             self.paper_client.get_paper_attribute(paper_id, type_name)
             for paper_id in related_paper_id_list
@@ -275,11 +271,10 @@ class Retriever(object):
                 break
         return paper_id_list
 
-    def cosine_similarity_search(self,
+    def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
         """
        return related paper: list
        """
-        embedding = self.embedding_model.encode(context)
        result = self.paper_client.cosine_similarity_search(
            embedding, k, type_name=type_name
        )
@@ -506,8 +501,9 @@ class SNRetriever(Retriever):
 
     def retrieve_paper(self, bg):
         entities = []
+        embedding = self.embedding_model.encode(bg, device=self.device)
         sn_paper_id_list = self.cosine_similarity_search(
-
+            embedding=embedding,
             k=self.config.RETRIEVE.sn_retrieve_paper_num,
         )
         related_paper = set()
@@ -524,6 +520,7 @@ class SNRetriever(Retriever):
         related_paper = list(related_paper)
         logger.debug(f"paper num before filter: {len(related_paper)}")
         result = {
+            "embedding": embedding,
             "paper": related_paper,
             "entities": entities,
             "cocite_paper": list(cocite_id_set),
@@ -548,7 +545,7 @@ class SNRetriever(Retriever):
         related_paper_id_list = retrieve_result["paper"]
         retrieve_paper_num = len(related_paper_id_list)
         _, _, score_all_dict = self.cal_related_score(
-
+            retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
         )
         top_k_matrix = {}
         recall = 0
@@ -626,8 +623,9 @@ class KGRetriever(Retriever):
         retrieve_result = self.retrieve_paper(entities)
         related_paper_id_list = retrieve_result["paper"]
         retrieve_paper_num = len(related_paper_id_list)
+        embedding = self.embedding_model.encode(bg, device=self.device)
         _, _, score_all_dict = self.cal_related_score(
-
+            embedding, related_paper_id_list=related_paper_id_list
         )
         top_k_matrix = {}
         recall = 0
@@ -668,8 +666,9 @@ class SNKGRetriever(Retriever):
 
     def retrieve_paper(self, bg, entities):
         sn_entities = []
+        embedding = self.embedding_model.encode(bg, device=self.device)
         sn_paper_id_list = self.cosine_similarity_search(
-
+            embedding, k=self.config.RETRIEVE.sn_num_for_entity
         )
         related_paper = set()
         related_paper.update(sn_paper_id_list)
@@ -689,6 +688,7 @@ class SNKGRetriever(Retriever):
         related_paper = related_paper.union(cocite_id_set)
         related_paper = list(related_paper)
         result = {
+            "embedding": embedding,
             "paper": related_paper,
             "entities": entities,
             "cocite_paper": list(cocite_id_set),
@@ -717,7 +717,7 @@ class SNKGRetriever(Retriever):
         retrieve_paper_num = len(related_paper_id_list)
         logger.info("=== Begin cal related paper score ===")
         _, _, score_all_dict = self.cal_related_score(
-
+            retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
         )
         logger.info("=== End cal related paper score ===")
         top_k_matrix = {}
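Taken together, the hunks replace per-call text encoding (cal_related_score and cosine_similarity_search each used to call self.embedding_model.encode themselves) with a single encode of the background text in retrieve_paper, whose vector is then passed to the search and scoring steps and cached under result["embedding"]. Below is a minimal, self-contained sketch of that pattern; encode, fake_paper_store, and the plain dot-product scoring are stand-ins for the repository's embedding model and paper_client, not its actual API.

import numpy as np


def encode(text: str, dim: int = 8) -> np.ndarray:
    """Toy stand-in for an embedding model (fixed per text within a run)."""
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    vec = rng.normal(size=dim)
    return vec / np.linalg.norm(vec)


# Pretend paper store: paper_id -> precomputed, normalized embedding.
fake_paper_store = {f"paper_{i}": encode(f"paper text {i}") for i in range(5)}


def cosine_similarity_search(embedding: np.ndarray, k: int = 3) -> list[str]:
    """Takes an embedding (not raw text), so nothing is re-encoded here."""
    scores = {pid: float(np.dot(embedding, vec)) for pid, vec in fake_paper_store.items()}
    return sorted(scores, key=scores.get, reverse=True)[:k]


def cal_related_score(embedding: np.ndarray, related_paper_id_list: list[str]) -> np.ndarray:
    """Also reuses the caller's embedding instead of encoding the text again."""
    paper_vecs = np.stack([fake_paper_store[pid] for pid in related_paper_id_list])
    return paper_vecs @ embedding


def retrieve_paper(bg: str) -> dict:
    # Encode the background text exactly once...
    embedding = encode(bg)
    paper_ids = cosine_similarity_search(embedding, k=3)
    # ...and stash the vector in the result so later stages can reuse it,
    # mirroring result["embedding"] in the diff above.
    return {"embedding": embedding, "paper": paper_ids}


result = retrieve_paper("graph-based paper retrieval")
scores = cal_related_score(result["embedding"], result["paper"])
print(result["paper"], scores)

The speed-up comes simply from invoking the embedding model once per query rather than once per retrieval stage.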