lihuigu committed
Commit 69e60be · 1 Parent(s): 382638a
configs/datasets.yaml CHANGED
@@ -14,10 +14,10 @@ RETRIEVE:
   use_cluster_to_filter: False # use a clustering algorithm in the filter
   cite_type: "all_cite_id_list"
   limit_num: 100 # cap the number of papers per entity
-  sn_num_for_entity: 5 # number of SN search results used to expand each entity
+  sn_num_for_entity: 3 # number of SN search results used to expand each entity
   kg_jump_num: 1 # number of hops
   kg_cover_num: 3 # number of overlapping entities
-  sum_paper_num: 30 # maximum number of retrieved papers
+  sum_paper_num: 50 # maximum number of retrieved papers
   sn_retrieve_paper_num: 55 # papers retrieved via SN
   cocite_top_k: 1
   need_normalize: True
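For orientation, a minimal sketch of how a RETRIEVE block like this is typically consumed; the loader shown here (yaml plus EasyDict for attribute-style access, matching the `config.RETRIEVE.sum_paper_num` usage in src/utils/paper_retriever.py) is an assumption, not necessarily this repo's actual loader.

```python
# Hypothetical loader sketch: reads the YAML above and exposes the RETRIEVE
# block with attribute access, mirroring config.RETRIEVE.sum_paper_num as
# used in paper_retriever.py. EasyDict is an assumed choice, not confirmed.
import yaml
from easydict import EasyDict

with open("configs/datasets.yaml") as f:
    config = EasyDict(yaml.safe_load(f))

# Values after this commit: a smaller SN expansion per entity,
# a larger total retrieval budget.
print(config.RETRIEVE.sn_num_for_entity)  # 3
print(config.RETRIEVE.sum_paper_num)      # 50
```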
src/retriever.py CHANGED
@@ -26,7 +26,7 @@ def main(ctx):
 @click.option(
     "-c",
     "--config-path",
-    default="../configs/datasets.yaml",
+    default="./configs/datasets.yaml",
     type=click.File(),
     required=True,
     help="Dataset configuration file in YAML",
src/utils/paper_client.py CHANGED
@@ -130,7 +130,6 @@ class PaperClient:
                     related_entities.add(entity)

             return list(related_entities)
-
         related_entities = bfs_query(entity_name, n, k)
         if entity_name in related_entities:
             related_entities.remove(entity_name)
@@ -541,6 +540,7 @@ class PaperClient:
         data = {"nodes": [], "relationships": []}
         query = """
         MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
+        WHERE p.venue_name='iclr' and p.year='2024'
         RETURN p, e, r
         """
         results = graph.run(query)
@@ -572,6 +572,7 @@ class PaperClient:
         WHERE p.venue_name='acl' and p.year='2024'
         RETURN p
         """
+        """
         results = graph.run(query)
         for record in tqdm(results):
             paper_node = record["p"]
@@ -581,6 +582,7 @@ class PaperClient:
                 "label": "Paper",
                 "properties": dict(paper_node)
             })
+        """
         # remove duplicate nodes
         # data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
         unique_nodes = []
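The added WHERE clause hardcodes the venue and year in the query string. A hedged alternative, assuming the py2neo `Graph` API that the surrounding `graph.run(...)` calls suggest, passes them as query parameters instead:

```python
# Parameterized variant of the query patched above; assumes py2neo
# (consistent with the graph.run(...) calls in this file). The connection
# URI and credentials are placeholders.
from py2neo import Graph

graph = Graph("bolt://localhost:7687", auth=("neo4j", "password"))
query = """
MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
WHERE p.venue_name = $venue and p.year = $year
RETURN p, e, r
"""
for record in graph.run(query, venue="iclr", year="2024"):
    print(dict(record["p"]))
```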
src/utils/paper_retriever.py CHANGED
@@ -124,7 +124,7 @@ class Retriever(object):
         )
         sum_paper_num = 0
         for key, value in entity_paper_num_dict.items():
-            if sum_paper_num <= 100:
+            if sum_paper_num <= self.config.RETRIEVE.sum_paper_num:
                 sum_paper_num += value
                 new_entities.append(key)
             elif (
@@ -188,35 +188,27 @@ class Retriever(object):
         return similarity

     def cal_related_score(
-        self, context, related_paper_id_list, entities=None, type_name="motivation"
+        self, context, related_paper_id_list, entities=None, type_name="embedding"
     ):
         score_1 = np.zeros((len(related_paper_id_list)))
         score_2 = np.zeros((len(related_paper_id_list)))
         if entities is None:
             entities = self.api_helper.generate_entity_list(context)
-        logger.debug("get entity from context: {}".format(entities))
         origin_vector = self.embedding_model.encode(
             context, convert_to_tensor=True, device=self.device
         ).unsqueeze(0)
-        related_contexts = [
+        context_embeddings = [
             self.paper_client.get_paper_attribute(paper_id, type_name)
             for paper_id in related_paper_id_list
         ]
-        if len(related_contexts) > 0:
-            context_embeddings = self.embedding_model.encode(
-                related_contexts,
-                batch_size=512,
-                convert_to_tensor=True,
-                device=self.device,
-            )
+        if len(context_embeddings) > 0:
+            context_embeddings = torch.tensor(context_embeddings).to(self.device)
             score_1 = torch.nn.functional.cosine_similarity(
                 origin_vector, context_embeddings
             )
             score_1 = score_1.cpu().numpy()
             if self.config.RETRIEVE.need_normalize:
                 score_1 = score_1 / np.max(score_1)
-
-        # score_2 not enable
-        # if self.config.RETRIEVE.beta != 0:
         score_sn_dict = dict(zip(related_paper_id_list, score_1))
         score_en_dict = dict(zip(related_paper_id_list, score_2))
         score_all_dict = dict(
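The cal_related_score change swaps on-the-fly encoding of paper texts for embeddings that are already stored per paper and merely stacked into a tensor. Below is a self-contained sketch of the new scoring path, with stand-in data where the original reads `paper_client.get_paper_attribute(paper_id, "embedding")`; the model name is illustrative.

```python
# Sketch of the new scoring path: stored per-paper embedding lists are
# stacked into one tensor and compared to the freshly encoded query context
# by cosine similarity, then max-normalized (need_normalize: True).
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer("all-MiniLM-L6-v2", device=device)  # illustrative model

context = "example query context"
origin_vector = model.encode(
    context, convert_to_tensor=True, device=device
).unsqueeze(0)

# Stand-in for paper_client.get_paper_attribute(paper_id, "embedding"):
# one precomputed embedding list per related paper.
stored_embeddings = [
    model.encode(text).tolist()
    for text in ["paper A abstract", "paper B abstract"]
]

if len(stored_embeddings) > 0:
    context_embeddings = torch.tensor(stored_embeddings).to(device)
    score_1 = torch.nn.functional.cosine_similarity(origin_vector, context_embeddings)
    score_1 = score_1.cpu().numpy()
    score_1 = score_1 / np.max(score_1)
    print(score_1)
```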