fix bug

- configs/datasets.yaml +2 -2
- src/retriever.py +1 -1
- src/utils/paper_client.py +3 -1
- src/utils/paper_retriever.py +5 -13
configs/datasets.yaml

```diff
@@ -14,10 +14,10 @@ RETRIEVE:
   use_cluster_to_filter: False # use a clustering algorithm in the filter
   cite_type: "all_cite_id_list"
   limit_num: 100 # limit the number of papers per entity
-  sn_num_for_entity:
+  sn_num_for_entity: 3 # number of papers per SN search, used to expand entities
   kg_jump_num: 1 # number of KG hops
   kg_cover_num: 3 # number of overlapping entities
-  sum_paper_num:
+  sum_paper_num: 50 # maximum number of papers retrieved in total
   sn_retrieve_paper_num: 55 # number of papers retrieved via SN
   cocite_top_k: 1
   need_normalize: True
```
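The two restored values feed the retrieval caps used elsewhere in this commit: `sum_paper_num: 50` bounds the entity-expansion loop in `paper_retriever.py`, while `sn_retrieve_paper_num: 55` separately bounds the SN search. A minimal sketch of reading these keys back, assuming plain PyYAML loading (the project code uses attribute access like `config.RETRIEVE.sum_paper_num`, so it presumably wraps the dict in something richer):

```python
# Minimal sketch, assuming the file is loaded with plain PyYAML; the real
# project exposes attribute-style access (config.RETRIEVE.sum_paper_num),
# so an extra wrapper object is likely involved.
import yaml

with open("configs/datasets.yaml") as f:
    config = yaml.safe_load(f)

retrieve = config["RETRIEVE"]
print(retrieve["sn_num_for_entity"])  # 3  - papers per SN search, to expand entities
print(retrieve["sum_paper_num"])      # 50 - cap on total retrieved papers
```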
src/retriever.py

```diff
@@ -26,7 +26,7 @@ def main(ctx):
 @click.option(
     "-c",
     "--config-path",
-    default="
+    default="./configs/datasets.yaml",
     type=click.File(),
     required=True,
     help="Dataset configuration file in YAML",
```
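For context, a self-contained sketch of the pattern this option uses: `type=click.File()` makes click open the path (including the new default) and hand the callback a file object rather than a string. One side effect worth noting: once a `default` is set, `required=True` no longer has any practical effect, because the option always resolves to a value. The `yaml.safe_load` call below is an assumption about how the file is parsed downstream:

```python
# Self-contained sketch of the click pattern in retriever.py; yaml.safe_load
# is an assumption about how the config file is parsed downstream.
import click
import yaml


@click.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,  # redundant once a default exists
    help="Dataset configuration file in YAML",
)
def main(config_path):
    config = yaml.safe_load(config_path)  # config_path is an open file handle
    click.echo(sorted(config))


if __name__ == "__main__":
    main()
```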
src/utils/paper_client.py

```diff
@@ -130,7 +130,6 @@ class PaperClient:
                     related_entities.add(entity)
 
             return list(related_entities)
-
         related_entities = bfs_query(entity_name, n, k)
         if entity_name in related_entities:
             related_entities.remove(entity_name)
@@ -541,6 +540,7 @@ class PaperClient:
         data = {"nodes": [], "relationships": []}
         query = """
         MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
+        WHERE p.venue_name='iclr' and p.year='2024'
         RETURN p, e, r
         """
         results = graph.run(query)
@@ -572,6 +572,7 @@ class PaperClient:
         WHERE p.venue_name='acl' and p.year='2024'
         RETURN p
         """
+        """
         results = graph.run(query)
         for record in tqdm(results):
             paper_node = record["p"]
@@ -581,6 +582,7 @@ class PaperClient:
                 "label": "Paper",
                 "properties": dict(paper_node)
             })
+        """
         # remove duplicate nodes
         # data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
         unique_nodes = []
```
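The two added `"""` lines wrap the ACL-2024 export loop in a bare string literal: Python evaluates the string and discards it, so the block is disabled without being deleted. A minimal illustration of the trick:

```python
# A bare string literal is an expression statement: evaluated, then discarded.
# Everything inside it is text, not code.
x = 1
"""
x = 2  # never runs; this line is part of the string above
"""
assert x == 1
```

Commenting the block out survives refactors better, since the trick breaks if the wrapped code ever contains its own triple quotes. Likewise, the hardcoded `'iclr'`/`'2024'` filter added to the first query could be supplied as Cypher parameters (`$venue`/`$year` placeholders passed to `graph.run`, if `graph` is a py2neo `Graph`) rather than baked into the string.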
src/utils/paper_retriever.py

```diff
@@ -124,7 +124,7 @@ class Retriever(object):
         )
         sum_paper_num = 0
         for key, value in entity_paper_num_dict.items():
-            if sum_paper_num <= 
+            if sum_paper_num <= self.config.RETRIEVE.sum_paper_num:
                 sum_paper_num += value
                 new_entities.append(key)
             elif (
@@ -188,35 +188,27 @@ class Retriever(object):
         return similarity
 
     def cal_related_score(
-        self, context, related_paper_id_list, entities=None, type_name="
+        self, context, related_paper_id_list, entities=None, type_name="embedding"
     ):
         score_1 = np.zeros((len(related_paper_id_list)))
         score_2 = np.zeros((len(related_paper_id_list)))
         if entities is None:
             entities = self.api_helper.generate_entity_list(context)
-            logger.debug("get entity from context: {}".format(entities))
         origin_vector = self.embedding_model.encode(
             context, convert_to_tensor=True, device=self.device
         ).unsqueeze(0)
-        related_contexts = [
+        context_embeddings = [
            self.paper_client.get_paper_attribute(paper_id, type_name)
             for paper_id in related_paper_id_list
         ]
-        if len(related_contexts) > 0:
-            context_embeddings = self.embedding_model.encode(
-                related_contexts,
-                batch_size=512,
-                convert_to_tensor=True,
-                device=self.device,
-            )
+        if len(context_embeddings) > 0:
+            context_embeddings = torch.tensor(context_embeddings).to(self.device)
         score_1 = torch.nn.functional.cosine_similarity(
             origin_vector, context_embeddings
         )
         score_1 = score_1.cpu().numpy()
         if self.config.RETRIEVE.need_normalize:
             score_1 = score_1 / np.max(score_1)
-        # score_2 not enable
-        # if self.config.RETRIEVE.beta != 0:
         score_sn_dict = dict(zip(related_paper_id_list, score_1))
         score_en_dict = dict(zip(related_paper_id_list, score_2))
         score_all_dict = dict(
```