lihuigu committed on
Commit
b926e53
·
1 Parent(s): 69e60be
Files changed (1) hide show
  1. src/utils/paper_retriever.py +13 -13
src/utils/paper_retriever.py CHANGED
@@ -188,15 +188,11 @@ class Retriever(object):
188
  return similarity
189
 
190
  def cal_related_score(
191
- self, context, related_paper_id_list, entities=None, type_name="embedding"
192
  ):
193
  score_1 = np.zeros((len(related_paper_id_list)))
194
  score_2 = np.zeros((len(related_paper_id_list)))
195
- if entities is None:
196
- entities = self.api_helper.generate_entity_list(context)
197
- origin_vector = self.embedding_model.encode(
198
- context, convert_to_tensor=True, device=self.device
199
- ).unsqueeze(0)
200
  context_embeddings = [
201
  self.paper_client.get_paper_attribute(paper_id, type_name)
202
  for paper_id in related_paper_id_list
@@ -275,11 +271,10 @@ class Retriever(object):
275
  break
276
  return paper_id_list
277
 
278
- def cosine_similarity_search(self, context, k=1, type_name="embedding"):
279
  """
280
  return related paper: list
281
  """
282
- embedding = self.embedding_model.encode(context)
283
  result = self.paper_client.cosine_similarity_search(
284
  embedding, k, type_name=type_name
285
  )
@@ -506,8 +501,9 @@ class SNRetriever(Retriever):
506
 
507
  def retrieve_paper(self, bg):
508
  entities = []
 
509
  sn_paper_id_list = self.cosine_similarity_search(
510
- context=bg,
511
  k=self.config.RETRIEVE.sn_retrieve_paper_num,
512
  )
513
  related_paper = set()
@@ -524,6 +520,7 @@ class SNRetriever(Retriever):
524
  related_paper = list(related_paper)
525
  logger.debug(f"paper num before filter: {len(related_paper)}")
526
  result = {
 
527
  "paper": related_paper,
528
  "entities": entities,
529
  "cocite_paper": list(cocite_id_set),
@@ -548,7 +545,7 @@ class SNRetriever(Retriever):
548
  related_paper_id_list = retrieve_result["paper"]
549
  retrieve_paper_num = len(related_paper_id_list)
550
  _, _, score_all_dict = self.cal_related_score(
551
- bg, related_paper_id_list=related_paper_id_list, entities=entities
552
  )
553
  top_k_matrix = {}
554
  recall = 0
@@ -626,8 +623,9 @@ class KGRetriever(Retriever):
626
  retrieve_result = self.retrieve_paper(entities)
627
  related_paper_id_list = retrieve_result["paper"]
628
  retrieve_paper_num = len(related_paper_id_list)
 
629
  _, _, score_all_dict = self.cal_related_score(
630
- bg, related_paper_id_list=related_paper_id_list, entities=entities
631
  )
632
  top_k_matrix = {}
633
  recall = 0
@@ -668,8 +666,9 @@ class SNKGRetriever(Retriever):
668
 
669
  def retrieve_paper(self, bg, entities):
670
  sn_entities = []
 
671
  sn_paper_id_list = self.cosine_similarity_search(
672
- context=bg, k=self.config.RETRIEVE.sn_num_for_entity
673
  )
674
  related_paper = set()
675
  related_paper.update(sn_paper_id_list)
@@ -689,6 +688,7 @@ class SNKGRetriever(Retriever):
689
  related_paper = related_paper.union(cocite_id_set)
690
  related_paper = list(related_paper)
691
  result = {
 
692
  "paper": related_paper,
693
  "entities": entities,
694
  "cocite_paper": list(cocite_id_set),
@@ -717,7 +717,7 @@ class SNKGRetriever(Retriever):
717
  retrieve_paper_num = len(related_paper_id_list)
718
  logger.info("=== Begin cal related paper score ===")
719
  _, _, score_all_dict = self.cal_related_score(
720
- bg, related_paper_id_list=related_paper_id_list, entities=entities
721
  )
722
  logger.info("=== End cal related paper score ===")
723
  top_k_matrix = {}
 
188
  return similarity
189
 
190
  def cal_related_score(
191
+ self, embedding, related_paper_id_list, type_name="embedding"
192
  ):
193
  score_1 = np.zeros((len(related_paper_id_list)))
194
  score_2 = np.zeros((len(related_paper_id_list)))
195
+ origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
 
 
 
 
196
  context_embeddings = [
197
  self.paper_client.get_paper_attribute(paper_id, type_name)
198
  for paper_id in related_paper_id_list
 
271
  break
272
  return paper_id_list
273
 
274
+ def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
275
  """
276
  return related paper: list
277
  """
 
278
  result = self.paper_client.cosine_similarity_search(
279
  embedding, k, type_name=type_name
280
  )
 
501
 
502
  def retrieve_paper(self, bg):
503
  entities = []
504
+ embedding = self.embedding_model.encode(bg, device=self.device)
505
  sn_paper_id_list = self.cosine_similarity_search(
506
+ embedding=embedding,
507
  k=self.config.RETRIEVE.sn_retrieve_paper_num,
508
  )
509
  related_paper = set()
 
520
  related_paper = list(related_paper)
521
  logger.debug(f"paper num before filter: {len(related_paper)}")
522
  result = {
523
+ "embedding": embedding,
524
  "paper": related_paper,
525
  "entities": entities,
526
  "cocite_paper": list(cocite_id_set),
 
545
  related_paper_id_list = retrieve_result["paper"]
546
  retrieve_paper_num = len(related_paper_id_list)
547
  _, _, score_all_dict = self.cal_related_score(
548
+ retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
549
  )
550
  top_k_matrix = {}
551
  recall = 0
 
623
  retrieve_result = self.retrieve_paper(entities)
624
  related_paper_id_list = retrieve_result["paper"]
625
  retrieve_paper_num = len(related_paper_id_list)
626
+ embedding = self.embedding_model.encode(bg, device=self.device)
627
  _, _, score_all_dict = self.cal_related_score(
628
+ embedding, related_paper_id_list=related_paper_id_list
629
  )
630
  top_k_matrix = {}
631
  recall = 0
 
666
 
667
  def retrieve_paper(self, bg, entities):
668
  sn_entities = []
669
+ embedding = self.embedding_model.encode(bg, device=self.device)
670
  sn_paper_id_list = self.cosine_similarity_search(
671
+ embedding, k=self.config.RETRIEVE.sn_num_for_entity
672
  )
673
  related_paper = set()
674
  related_paper.update(sn_paper_id_list)
 
688
  related_paper = related_paper.union(cocite_id_set)
689
  related_paper = list(related_paper)
690
  result = {
691
+ "embedding": embedding,
692
  "paper": related_paper,
693
  "entities": entities,
694
  "cocite_paper": list(cocite_id_set),
 
717
  retrieve_paper_num = len(related_paper_id_list)
718
  logger.info("=== Begin cal related paper score ===")
719
  _, _, score_all_dict = self.cal_related_score(
720
+ retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
721
  )
722
  logger.info("=== End cal related paper score ===")
723
  top_k_matrix = {}