James Kelly
committed on
Commit • 8155451
1 Parent(s): 23d06eb
cloned most of ctmatch into this spaces repo... it will have to handle the data too, we'll see. using ctmatch requirements.txt
Browse files
- app.py +27 -0
- ctmatch/__init__.py +0 -0
- ctmatch/ct_data_paths.py +92 -0
- ctmatch/ctmatch_prep.py +351 -0
- ctmatch/dataprep.py +152 -0
- ctmatch/eda.py +114 -0
- ctmatch/evaluator.py +154 -0
- ctmatch/match.py +333 -0
- ctmatch/models/classifier_model.py +396 -0
- ctmatch/models/gen_model.py +83 -0
- ctmatch/pipeconfig.py +42 -0
- ctmatch/pipetopic.py +10 -0
- ctmatch/scripts/build_combined_data.py +76 -0
- ctmatch/scripts/gen_categories.py +92 -0
- ctmatch/scripts/get_web_data.py +14 -0
- ctmatch/scripts/split_files.py +37 -0
- ctmatch/scripts/vis_script.py +334 -0
- ctmatch/utils/__init__.py +0 -0
- ctmatch/utils/ctmatch_utils.py +133 -0
- ctmatch/utils/eval_utils.py +91 -0
- requirements.txt +23 -0
app.py
ADDED
@@ -0,0 +1,27 @@
from ctmatch.match import CTMatch, PipeConfig
import gradio as gr


# build the matching pipeline once at startup: IR artifacts + pruned scibert classifier
pipe_config = PipeConfig(
    classifier_model_checkpoint='semaj83/scibert_finetuned_pruned_ctmatch',
    ir_setup=True,
    filters=["svm", "classifier"],
)

CTM = CTMatch(pipe_config)


def ctmatch_web_api(topic_query: str) -> str:
    # run the filter pipeline and format the top 5 (nct_id, text) pairs for display
    return '\n\n'.join([f"{nid}: {txt}" for nid, txt in CTM.match_pipeline(topic_query, top_k=5)])


if __name__ == "__main__":

    with gr.Blocks(css=".gradio-container {background-color: #00CED1}") as demo:
        name = gr.Textbox(lines=5, label="patient description", placeholder="Patient is a 45-year-old man with a history of anaplastic astrocytoma...")
        output = gr.Textbox(lines=10, label="matching trials")
        greet_btn = gr.Button("match")
        greet_btn.click(fn=ctmatch_web_api, inputs=name, outputs=output, api_name="match")

    demo.queue().launch(share=True, debug=True)
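Once the Space is running, the named endpoint can also be called remotely. A minimal sketch using gradio_client; the Space id here is a placeholder for wherever this repo is hosted:

from gradio_client import Client

# placeholder Space id -- substitute the actual <owner>/<space> name
client = Client("semaj83/ctmatch")
result = client.predict("Patient is a 45-year-old man with a history of anaplastic astrocytoma...", api_name="/match")
print(result)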
ctmatch/__init__.py
ADDED
File without changes
ctmatch/ct_data_paths.py
ADDED
@@ -0,0 +1,92 @@
from typing import List, Tuple

TREC_REL_PATH = "/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_21_judgments.txt"
KZ_REL_PATH = "/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/qrels-clinical_trials.txt"

TREC_RELLED_TOPIC_PATH = "/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec21_topics.jsonl"
KZ_RELLED_TOPIC_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_topics.jsonl'

KZ_DOC_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/clinicaltrials.gov-16_dec_2015.zip'
KZ_PROCESSED_DOC_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_docs.jsonl'

TREC_ML_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_data.jsonl'
KZ_ML_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/kz_data.jsonl'


def get_data_tuples(trec_or_kz: str = 'trec') -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    # returns (doc_data_tuples, topic_data_tuples) of (source, processed_target) path pairs
    if trec_or_kz == 'trec':
        return get_trec_doc_data_tuples(), get_trec_topic_data_tuples()
    return get_kz_doc_data_tuples(), get_kz_topic_data_tuples()


# --------------------------------------------------------------------------------------------------------------- #
# data from TREC clinical track 2021 & 2022
# --------------------------------------------------------------------------------------------------------------- #

def get_trec_doc_data_tuples() -> List[Tuple[str, str]]:
    trec22_pt1_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part1.zip'
    trec_pt1_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part1.jsonl'

    trec22_pt2_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part2.zip'
    trec_pt2_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part2.jsonl'

    trec22_pt3_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part3.zip'
    trec_pt3_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part3.jsonl'

    trec22_pt4_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part4.zip'
    trec_pt4_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part4.jsonl'

    trec22_pt5_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part5.zip'
    trec_pt5_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part5.jsonl'

    trec_doc_data_tuples = [
        (trec22_pt1_docs, trec_pt1_target),
        (trec22_pt2_docs, trec_pt2_target),
        (trec22_pt3_docs, trec_pt3_target),
        (trec22_pt4_docs, trec_pt4_target),
        (trec22_pt5_docs, trec_pt5_target)
    ]

    return trec_doc_data_tuples


def get_trec_topic_data_tuples() -> List[Tuple[str, str]]:
    trec21_topic_path = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_21_topics.xml'
    trec21_topic_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec21_topics.jsonl'
    trec22_topic_path = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_22_topics.xml'
    trec22_topic_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_topics.jsonl'

    trec_topic_data_tuples = [
        (trec21_topic_path, trec21_topic_target),
        (trec22_topic_path, trec22_topic_target)
    ]
    return trec_topic_data_tuples


# --------------------------------------------------------------------------------------------------------------- #
# data from Koopman and Zuccon (2016)
# --------------------------------------------------------------------------------------------------------------- #
def get_kz_doc_data_tuples() -> List[Tuple[str, str]]:
    # kz_doc_data_tuples = []
    # for i in range(1, 18):
    #     kz_doc_path = f'/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/kz_doc_splits/kz_doc_split{i}.zip'
    #     kz_doc_target = f'/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_doc_split{i}.jsonl'
    #     kz_doc_data_tuples.append((kz_doc_path, kz_doc_target))
    # return kz_doc_data_tuples
    kz_docs = KZ_DOC_PATH
    kz_docs_target = KZ_PROCESSED_DOC_PATH
    return [(kz_docs, kz_docs_target)]


def get_kz_topic_data_tuples() -> List[Tuple[str, str]]:
    kz_topic_desc_path = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/topics-2014_2015-description.topics'
    kz_topic_target = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_topics.jsonl'
    kz_topic_data_tuples = [
        (kz_topic_desc_path, kz_topic_target)
    ]
    return kz_topic_data_tuples
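A quick sketch of how these tuples are consumed (the loop body is illustrative; the real consumers live in ctmatch_prep.py):

from ct_data_paths import get_data_tuples

doc_tuples, topic_tuples = get_data_tuples('trec')
for source_zip, processed_target in doc_tuples:
    print(f"{source_zip} -> {processed_target}")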
ctmatch/ctmatch_prep.py
ADDED
@@ -0,0 +1,351 @@
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import ct_data_paths as ctpaths
import numpy as np
import random
import json

# CT processing classes provided by the ctproc project (not added in this commit)
from proc import CTConfig, CTProc, CTDocument, CTTopic
# these import names are adjusted to the modules actually added in this commit
# (scripts/vis_script.py, utils/ctmatch_utils.py, eda.py)
from scripts.vis_script import analyze_test_rels
from utils.ctmatch_utils import get_processed_data, truncate
import eda

LLM_END_PROMPT: str = "Relevance score (0, 1, or 2) : [CLS] "


class DataConfig(NamedTuple):
    save_path: str
    trec_or_kz: str = 'trec'
    filtered_topic_keys: Set[str] = {'id', 'text_sents', 'age', 'gender'}
    filtered_doc_keys: Set[str] = {'id', 'elig_min_age', 'elig_max_age', 'elig_gender', 'condition', 'elig_crit'}
    max_topic_len: Optional[int] = None
    max_inc_len: Optional[int] = None
    max_exc_len: Optional[int] = None
    prepend_elig_age: bool = True
    prepend_elig_gender: bool = True
    include_only: bool = False
    downsample_zeros_n: Optional[int] = None
    sep: str = '[SEP]'
    llm_prep: bool = False
    first_n_only: Optional[int] = None
    convert_snli: bool = False
    infer_category_model: Optional[str] = None


def proc_docs_and_topics(trec_or_kz: str = 'trec') -> Tuple[Dict[str, Dict[str, str]], Dict[str, Dict[str, str]]]:

    doc_tuples, topic_tuples = ctpaths.get_data_tuples(trec_or_kz)

    id2topic = dict()
    for topic_source, topic_target in topic_tuples:
        id2topic.update(proc_topics(topic_source, topic_target, trec_or_kz=trec_or_kz))
        print(f"processed {trec_or_kz} topic source: {topic_source}, and wrote to {topic_target}")

    id2doc = dict()
    for doc_source, doc_target in doc_tuples:
        id2doc.update(proc_docs(doc_source, doc_target))
        print(f"processed {trec_or_kz} doc source: {doc_source}, and wrote to {doc_target}")

    return id2topic, id2doc


def proc_docs(doc_path: str, output_path: str) -> Dict[str, CTDocument]:

    ct_config = CTConfig(
        data_path=doc_path,
        write_file=output_path,
        nlp=True
    )

    cp = CTProc(ct_config)
    id2doc = {res.id: res for res in cp.process_data()}
    return id2doc


def proc_topics(topic_path: str, output_path: str, trec_or_kz: str = 'trec') -> Dict[str, CTTopic]:

    ct_config = CTConfig(
        data_path=topic_path,
        write_file=output_path,
        nlp=True,
        is_topic=True,
        trec_or_kz=trec_or_kz
    )

    cp = CTProc(ct_config)
    id2topic = {res.id: res for res in cp.process_data()}
    return id2topic


def filter_doc_for_ir(doc, dconfig) -> Dict[str, List[str]]:
    new_doc = dict()
    new_doc['id'] = doc['id']
    new_doc['text'] = prep_doc_text(doc, dconfig)
    return new_doc


def prep_ir_dataset(dconfig: DataConfig):
    # need a file of all docs with their
    # 1. ids,
    # 2. combined text...
    # 3.

    # get path to processed docs
    doc_tuples, _ = ctpaths.get_data_tuples(dconfig.trec_or_kz)

    # get all processed docs
    id2doc = dict()
    for _, processed_doc_path in doc_tuples:
        print(f"getting docs from {processed_doc_path}")
        for doc in get_processed_data(processed_doc_path):
            # capture the category dict before filtering drops it
            category = doc['category']
            doc = filter_doc_for_ir(doc, dconfig)
            # makes a consistently ordered category vector: sort the keys, then take their values
            doc['category'] = np.asarray([category[k] for k in sorted(category)])
            id2doc[doc['id']] = doc
    return id2doc


# --------------------------------------------------------------------------------------------------------------- #
# pre-processing functions to save a form of triples for a particular model spec
# --------------------------------------------------------------------------------------------------------------- #

def prep_fine_tuning_dataset(
    dconfig: DataConfig
) -> None:
    """
    trec_or_kz: 'trec' or 'kz'
    desc: create dict of triplets of topic, doc, relevancy scores,
          save into a single jsonl file
    """
    print(f"trec_or_kz: {dconfig.trec_or_kz}")
    topic_path, rel_path = get_topic_and_rel_path(dconfig.trec_or_kz)

    # get set of all relevant doc ids
    rel_type_dict, rel_dict, all_qrelled_docs = analyze_test_rels(rel_path)

    # get path to processed docs (already got topic path)
    doc_tuples, _ = ctpaths.get_data_tuples(dconfig.trec_or_kz)

    # get mappings of doc ids to doc dicts and topic ids to topic dicts
    id2doc, id2topic = get_doc_and_topic_mappings(all_qrelled_docs, doc_tuples, topic_path)
    print(len(id2doc), len(all_qrelled_docs))

    missing_docs = set()
    skipped = 0

    # save combined triples of doc, topic, relevancy score
    with open(dconfig.save_path, 'w') as f:
        print(f"saving to: {dconfig.save_path}")

        for topic_id in rel_dict:
            for doc_id in rel_dict[topic_id]:
                label = rel_dict[topic_id][doc_id]
                if downsample_zero(label, rel_type_dict['0'], dconfig):
                    skipped += 1
                    continue

                if doc_id in id2doc:
                    combined = create_combined_doc(
                        id2doc[doc_id],
                        id2topic[topic_id],
                        label,
                        dconfig=dconfig,
                    )

                    # save to file as jsonl
                    f.write(json.dumps(combined))
                    f.write('\n')
                else:
                    missing_docs.add(doc_id)

    print(f"number of docs missing: {len(missing_docs)}, number of zeros skipped: {skipped}")
    for md in missing_docs:
        print(md)


def create_combined_doc(
    doc, topic,
    rel_score,
    dconfig: DataConfig,
):
    combined = dict()

    # get filtered and truncated and SEP tokenized topic text
    combined['topic'] = prep_topic_text(topic, dconfig)

    # get filtered and truncated and SEP tokenized doc text
    combined['doc'] = prep_doc_text(doc, dconfig)

    # get relevancy score as string
    if dconfig.convert_snli:
        rel_score = convert_label_snli(rel_score)

    combined['label'] = str(rel_score)

    return combined


def convert_label_snli(label: int) -> int:
    # swap labels 1 and 2 to line up with the SNLI class ordering
    if label == 2:
        return 1
    elif label == 1:
        return 2
    return label


def downsample_zero(label: int, zero_ct: int, dconfig: DataConfig) -> bool:
    if dconfig.downsample_zeros_n is not None:
        if (label == 0) and (random.random() > (dconfig.downsample_zeros_n / zero_ct)):
            return True
    return False


def prep_topic_text(topic: Dict[str, Union[List[str], str, float]], dconfig: DataConfig) -> str:
    topic_text = ' '.join(topic['text_sents'])
    topic_text = truncate(topic_text, dconfig.max_topic_len)
    return topic_text


def get_n_crit(crit_list: List[str], dconfig: DataConfig) -> List[str]:
    if dconfig.first_n_only is not None:
        crit_list = crit_list[:min(len(crit_list), dconfig.first_n_only)]
    return crit_list


def prep_doc_text(doc: Dict[str, Union[List[str], str, float]], dconfig: DataConfig) -> str:

    # combine lists of strings into single string
    doc_inc = ' '.join(get_n_crit(doc['elig_crit']['include_criteria'], dconfig))
    doc_exc = ' '.join(get_n_crit(doc['elig_crit']['exclude_criteria'], dconfig))

    if 'condition' in dconfig.filtered_doc_keys:
        doc_inc = f"{' '.join(doc['condition'])} {doc_inc}"
        if dconfig.llm_prep:
            doc_inc = "Condition: " + doc_inc + ", "

    # truncate criteria separately if set in config
    doc_inc = truncate(doc_inc, dconfig.max_inc_len)
    doc_exc = truncate(doc_exc, dconfig.max_exc_len)

    if dconfig.prepend_elig_gender:
        doc_inc = f"{doc['elig_gender']} {dconfig.sep} {doc_inc}"
        if dconfig.llm_prep:
            doc_inc = "Gender: " + doc_inc + ", "

    if dconfig.prepend_elig_age:
        if dconfig.llm_prep:
            doc_inc = f"Trial Doc: A person who is between {doc['elig_min_age']}-{doc['elig_max_age']} years old who meets the following Inclusion Criteria: {doc_inc}"
        else:
            doc_inc = f"eligible ages (years): {doc['elig_min_age']}-{doc['elig_max_age']}, {dconfig.sep} {doc_inc}"

    # combine criteria into single string
    if dconfig.include_only:
        if dconfig.llm_prep:
            doc_inc += LLM_END_PROMPT
        return doc_inc

    if dconfig.llm_prep:
        return f"{doc_inc} and does not meet these Exclusion Criteria: {doc_exc} {LLM_END_PROMPT}"

    return f"{doc_inc} {dconfig.sep} {doc_exc}"


# --------------------------------------------------------------------------------------------------------------- #
# utility functions
# --------------------------------------------------------------------------------------------------------------- #

def age_match(min_doc_age: float, max_doc_age: float, topic_age: float) -> bool:
    if topic_age < min_doc_age:
        return False
    if topic_age > max_doc_age:
        return False
    return True


def gender_match(doc_gender: str, topic_gender: str) -> bool:
    if doc_gender == 'All':
        return True
    if doc_gender == topic_gender:
        return True
    return False


def get_topic_and_rel_path(trec_or_kz: str = 'trec') -> Tuple[str, str]:
    if trec_or_kz == 'trec':
        rel_path = ctpaths.TREC_REL_PATH
        topic_path = ctpaths.TREC_RELLED_TOPIC_PATH
    else:
        rel_path = ctpaths.KZ_REL_PATH
        topic_path = ctpaths.KZ_RELLED_TOPIC_PATH
    return topic_path, rel_path


def get_doc_and_topic_mappings(all_qrelled_docs: Set[str], doc_tuples: List[Tuple[str, str]], topic_path: str) -> Tuple[Dict[str, Dict[str, str]], Dict[str, Dict[str, str]]]:
    """
    desc: get mappings of doc ids to doc dicts and topic ids to topic dicts
    """

    # get all processed topics
    id2topic = {t['id']: t for t in get_processed_data(topic_path)}

    # get all processed docs
    id2doc = dict()
    for _, processed_doc_path in doc_tuples:
        print(f"getting docs from {processed_doc_path}")
        for doc in get_processed_data(processed_doc_path):
            if doc['id'] in all_qrelled_docs:
                id2doc[doc['id']] = doc

    return id2doc, id2topic


if __name__ == '__main__':
    # proc_docs_and_topics('kz')
    # eda.explore_trec_data(part=2, rand_print=0.001)  # select part 1-5 (~70k docs per part)
    # eda.explore_kz_data(rand_print=0.00001)          # all in one file (~200k docs)

    # all available fields are listed in the DataConfig definition above
    dconfig = DataConfig(
        trec_or_kz='trec',
        save_path=ctpaths.TREC_ML_PATH,  # make sure to change this!
        sep='',
        first_n_only=10,
        max_topic_len=200,
        llm_prep=False,
        prepend_elig_age=True,
        prepend_elig_gender=False
    )
    prep_fine_tuning_dataset(dconfig)
    # eda.explore_prepped(ctpaths.TREC_KZ_PATH)
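For reference, a hypothetical example of one jsonl line written by prep_fine_tuning_dataset (the texts are invented; the keys and string label follow create_combined_doc):

{"topic": "Patient is a 45-year-old man with ...", "doc": "eligible ages (years): 18.0-99.0, [SEP] Inclusion criteria text ... [SEP] Exclusion criteria text ...", "label": "2"}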
ctmatch/dataprep.py
ADDED
@@ -0,0 +1,152 @@
# external imports
from datasets import Dataset, load_dataset, ClassLabel, Features, Value
from transformers import AutoTokenizer
import pandas as pd
import numpy as np

# package tools
from .utils.ctmatch_utils import train_test_val_split, get_processed_data, get_test_rels
from .pipeconfig import PipeConfig


# paths to the ctmatch datasets on the HF hub
CTMATCH_CLASSIFICATION_DATASET_ROOT = "semaj83/ctmatch_classification"
CTMATCH_IR_DATASET_ROOT = "semaj83/ctmatch_ir"
CLASSIFIER_DATA_PATH = "combined_classifier_data.jsonl"
DOC_TEXTS_PATH = "doc_texts.txt"
DOC_CATEGORIES_VEC_PATH = "doc_categories.txt"
DOC_EMBEDDINGS_VEC_PATH = "doc_embeddings.txt"
INDEX2DOCID_PATH = "index2docid.txt"


SUPPORTED_LMS = [
    'roberta-large', 'cross-encoder/nli-roberta-base',
    'microsoft/biogpt', 'allenai/scibert_scivocab_uncased',
    'facebook/bart-large', 'gpt2',
    'semaj83/scibert_finetuned_ctmatch', 'semaj83/scibert_finetuned_pruned_ctmatch'
]


class DataPrep:
    # multiple 'datasets' need to be prepared for the pipeline:
    # 1. the dataset for the classifier model: triplets and a dataframe, ~25k rows
    # 2. the dataset for the category model: every doc, ~200k rows
    # 3. the dataset for the embedding model: every doc, <200k rows

    def __init__(self, pipe_config: PipeConfig) -> None:
        self.pipe_config = pipe_config
        self.classifier_tokenizer = self.get_classifier_tokenizer()
        self.ct_dataset = None
        self.ct_train_dataset_df = None
        self.index2docid = None
        self.doc_embeddings_df = None
        self.doc_categories_df = None

        if pipe_config.ir_setup:
            self.load_ir_data()
        else:
            self.load_classifier_data()

    def get_classifier_tokenizer(self):
        model_checkpoint = self.pipe_config.classifier_model_checkpoint
        if model_checkpoint not in SUPPORTED_LMS:
            raise ValueError(f"Model checkpoint {model_checkpoint} not supported. Please use one of {SUPPORTED_LMS}")
        if 'scibert' in model_checkpoint:
            tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', use_fast=True)
        else:
            tokenizer = AutoTokenizer.from_pretrained(self.pipe_config.classifier_model_checkpoint)
        if self.pipe_config.classifier_model_checkpoint == 'gpt2':
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    # ------------------ Classifier Data Loading ------------------ #
    def load_classifier_data(self) -> Dataset:
        self.ct_dataset = load_dataset(CTMATCH_CLASSIFICATION_DATASET_ROOT, data_files=CLASSIFIER_DATA_PATH)
        self.ct_dataset = train_test_val_split(self.ct_dataset, self.pipe_config.splits, self.pipe_config.seed)
        self.add_features()
        self.tokenize_dataset()
        self.ct_dataset = self.ct_dataset.rename_column("label", "labels")
        # self.ct_dataset = self.ct_dataset.rename_column("topic", "sentence1")
        # self.ct_dataset = self.ct_dataset.rename_column("doc", "sentence2")
        self.ct_dataset.set_format(type='torch', columns=['doc', 'labels', 'topic', 'input_ids', 'attention_mask'])
        if not self.pipe_config.use_trainer:
            self.ct_dataset = self.ct_dataset.remove_columns(['doc', 'topic'])  # removing the raw text columns for next-token prediction...

        self.ct_train_dataset_df = self.ct_dataset['train'].remove_columns(['input_ids', 'attention_mask', 'token_type_ids']).to_pandas()

        return self.ct_dataset

    def add_features(self) -> None:
        if self.pipe_config.convert_snli:
            names = ['contradiction', 'entailment', 'neutral']
        else:
            names = ["not_relevant", "partially_relevant", "relevant"]

        features = Features({
            'doc': Value(dtype='string', id=None),
            'label': ClassLabel(names=names),
            'topic': Value(dtype='string', id=None)
        })
        self.ct_dataset["train"] = self.ct_dataset["train"].map(lambda x: x, batched=True, features=features)
        self.ct_dataset["test"] = self.ct_dataset["test"].map(lambda x: x, batched=True, features=features)
        self.ct_dataset["validation"] = self.ct_dataset["validation"].map(lambda x: x, batched=True, features=features)

    def tokenize_function(self, examples):
        return self.classifier_tokenizer(
            examples["topic"], examples["doc"],
            truncation=self.pipe_config.truncation,
            padding=self.pipe_config.padding,
            max_length=self.pipe_config.max_length
        )

    def tokenize_dataset(self):
        self.ct_dataset = self.ct_dataset.map(self.tokenize_function, batched=True)

    def get_category_data(self, vectorize=True):
        category_data = dict()
        sorted_cat_keys = None
        for cdata in get_processed_data(self.pipe_config.category_path):

            # cdata = {<nct_id>: {cat1: float1, cat2: float2...}}
            cdata_id, cdata_dict = list(cdata.items())[0]
            if sorted_cat_keys is None:
                sorted_cat_keys = sorted(cdata_dict.keys())

            if vectorize:
                cat_vec = np.asarray([cdata_dict[k] for k in sorted_cat_keys])
            else:
                cat_vec = cdata_dict

            category_data[cdata_id] = cat_vec
        return category_data

    # ------------------ IR Data Loading ------------------ #
    def process_ir_data_from_hf(self, ds_path, is_text: bool = False):
        ds = load_dataset(CTMATCH_IR_DATASET_ROOT, data_files=ds_path)
        if is_text:
            return pd.DataFrame(ds['train'])

        # numeric files store one comma-separated vector per line in the 'text' column
        arrays = [np.asarray(a['text'].split(','), dtype=float) for a in ds['train']]
        return pd.DataFrame(arrays)

    def load_ir_data(self) -> None:
        self.index2docid = self.process_ir_data_from_hf(INDEX2DOCID_PATH, is_text=True)
        self.doc_embeddings_df = self.process_ir_data_from_hf(DOC_EMBEDDINGS_VEC_PATH)
        self.doc_categories_df = self.process_ir_data_from_hf(DOC_CATEGORIES_VEC_PATH)
        self.doc_texts_df = self.process_ir_data_from_hf(DOC_TEXTS_PATH, is_text=True)
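A minimal usage sketch, assuming network access to the semaj83 datasets on the HF hub (this downloads the prebuilt IR artifacts):

from ctmatch.pipeconfig import PipeConfig
from ctmatch.dataprep import DataPrep

data = DataPrep(PipeConfig(
    classifier_model_checkpoint='semaj83/scibert_finetuned_pruned_ctmatch',
    ir_setup=True,
))
# one row per trial document: texts, embedding vectors, category vectors
print(data.doc_texts_df.shape, data.doc_embeddings_df.shape, data.doc_categories_df.shape)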
ctmatch/eda.py
ADDED
@@ -0,0 +1,114 @@
from typing import Dict, NamedTuple, Tuple
from utils.ctmatch_utils import get_processed_data
from collections import defaultdict
import ct_data_paths
import random

# import adjusted to the vis script actually added in this commit (scripts/vis_script.py)
from scripts.vis_script import (
    analyze_test_rels
)


class ExplorePaths(NamedTuple):
    doc_path: str
    topic_path: str
    rel_path: str


# --------------------------------------------------------------------------------------------------------------- #
# EDA functions
# --------------------------------------------------------------------------------------------------------------- #

def explore_kz_data(rand_print: float = 0.001) -> None:
    kz_data_paths = ExplorePaths(
        rel_path=ct_data_paths.KZ_REL_PATH,
        doc_path=ct_data_paths.KZ_PROCESSED_DOC_PATH,
        topic_path=ct_data_paths.KZ_RELLED_TOPIC_PATH
    )

    explore_data(kz_data_paths, rand_print=rand_print)


def explore_trec_data(part: int = 1, rand_print: float = 0.001) -> None:
    # post processing analysis
    trec_data_paths = ExplorePaths(
        rel_path=ct_data_paths.TREC_REL_PATH,
        doc_path=f'/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part{part}.jsonl',
        topic_path=ct_data_paths.TREC_RELLED_TOPIC_PATH
    )

    explore_data(trec_data_paths, rand_print=rand_print)


def explore_data(data_paths: ExplorePaths, rand_print: float) -> None:

    # process relevancy judgements
    type_dict, rel_dict, all_qrelled_docs = analyze_test_rels(data_paths.rel_path)

    # get processed topics
    id2topic = {t['id']: t for t in get_processed_data(data_paths.topic_path)}
    print(f"number of processed topics: {len(id2topic)}")

    # get relevant processed docs
    id2docs = {doc['id']: doc for doc in get_processed_data(data_paths.doc_path, get_only=all_qrelled_docs)}
    print(f"number of relevant processed docs: {len(id2docs)}")

    explore_pairs(id2topic, id2docs, rel_dict, max_print=1000, rand_print=rand_print)


def explore_pairs(id2topic: Dict[str, Dict[str, str]], id2docs: Dict[str, Dict[str, str]], rel_dict: Dict[str, Dict[str, str]], rand_print: float, max_print: int = 100000) -> None:
    rel_scores = defaultdict(int)
    age_mismatches, gender_mismatches = 0, 0
    for pt_id, topic in id2topic.items():
        for doc_id in rel_dict[pt_id]:
            if doc_id in id2docs:
                rel_score = rel_dict[pt_id][doc_id]
                rel_scores[rel_score] += 1
                if rel_score == 2:
                    age_mismatches, gender_mismatches = check_match(
                        topic=topic,
                        doc=id2docs[doc_id],
                        rel_score=rel_score,
                        age_mismatches=age_mismatches,
                        gender_mismatches=gender_mismatches
                    )

                if random.random() < rand_print:
                    print_pair(topic, id2docs[doc_id], rel_score, marker='%')

    print(rel_scores.items())
    print(f"{age_mismatches=}, {gender_mismatches=}")


def check_match(topic: Dict[str, str], doc: Dict[str, str], rel_score: int, age_mismatches: int, gender_mismatches: int) -> Tuple[int, int]:
    # imported here rather than at module level to avoid a circular import with ctmatch_prep
    from ctmatch_prep import age_match, gender_match

    age_matches = age_match(doc['elig_min_age'], doc['elig_max_age'], topic['age'])
    if not age_matches:
        # print_pair(topic, doc, rel_score)
        age_mismatches += 1

    gender_matches = gender_match(doc['elig_gender'], topic['gender'])
    if not gender_matches:
        # print_pair(topic, doc, rel_score)
        gender_mismatches += 1

    return age_mismatches, gender_mismatches


def print_pair(topic: Dict[str, str], doc: Dict[str, str], rel_score: int, marker: str = '*') -> None:
    print(marker * 200)
    print(f"topic id: {topic['id']}, nct_id: {doc['id']}, rel score: {rel_score}")
    print(f"topic info: \nage: {topic['age']}, gender: {topic['gender']}")
    print(topic['raw_text'])
    print(f"doc info: gender: {doc['elig_gender']}, min age: {doc['elig_min_age']}, max age: {doc['elig_max_age']}")
    print(doc['elig_crit']['raw_text'])
    print(marker * 200)
    print()
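Example invocations, mirroring the commented calls in ctmatch_prep's __main__ block:

explore_trec_data(part=2, rand_print=0.001)   # select part 1-5 (~70k docs per part)
explore_kz_data(rand_print=0.00001)           # all in one file (~200k docs)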
ctmatch/evaluator.py
ADDED
@@ -0,0 +1,154 @@
import logging
from typing import List, NamedTuple, Optional, Tuple, Union

from .utils.eval_utils import (
    calc_first_positive_rank, calc_f1, get_kz_topic2text, get_trec_topic2text
)
from .pipeconfig import PipeConfig
from .match import CTMatch
from pathlib import Path
from tqdm import tqdm
import numpy as np

logger = logging.getLogger(__name__)


class EvaluatorConfig(NamedTuple):
    rel_paths: List[str]
    trec_topic_path: Optional[Union[Path, str]] = None
    kz_topic_path: Optional[Union[Path, str]] = None
    max_topics: int = 200
    openai_api_key: Optional[str] = None
    filters: Optional[List[str]] = None
    sanity_check_ids: Optional[List[str]] = None


class Evaluator:
    def __init__(self, eval_config: EvaluatorConfig) -> None:
        self.rel_paths: List[str] = eval_config.rel_paths
        self.trec_topic_path: Optional[Union[Path, str]] = eval_config.trec_topic_path
        self.kz_topic_path: Optional[Union[Path, str]] = eval_config.kz_topic_path

        self.rel_dict: dict = None
        self.topicid2text: dict = None
        self.ctm = None
        self.openai_api_key = eval_config.openai_api_key
        self.filters = eval_config.filters
        self.sanity_check_ids = eval_config.sanity_check_ids

        assert self.rel_paths is not None, "paths to relevancy judgments must be set in pipe_config if pipe_config.evaluate=True"
        assert (self.trec_topic_path is not None) or (self.kz_topic_path is not None), "at least one of trec_topic_path or kz_topic_path must be set when pipe_config.evaluate=True"

        self.setup()

        self.max_topics: int = len(self.topicid2text) if eval_config.max_topics is None else min(len(self.topicid2text), eval_config.max_topics)

    def get_combined_rel_dict(self, rel_paths: List[str]) -> dict:
        combined_rel_dict = dict()
        for rel_path in rel_paths:
            with open(rel_path, 'r') as f:
                for line in f.readlines():
                    topic_id, _, doc_id, rel = line.split()
                    if topic_id not in combined_rel_dict:
                        combined_rel_dict[topic_id] = dict()
                    combined_rel_dict[topic_id][doc_id] = int(rel)
        return combined_rel_dict

    def setup(self):
        self.rel_dict = self.get_combined_rel_dict(self.rel_paths)
        self.topicid2text = dict()
        if self.kz_topic_path is not None:
            self.topicid2text = get_kz_topic2text(self.kz_topic_path)

        if self.trec_topic_path is not None:
            self.topicid2text.update(get_trec_topic2text(self.trec_topic_path))

        # loads all remaining needed datasets into memory
        pipe_config = PipeConfig(
            openai_api_key=self.openai_api_key,
            ir_setup=True,
            filters=self.filters
        )
        self.ctm = CTMatch(pipe_config=pipe_config)

    def evaluate(self):
        """
        desc: run the pipeline over every topic and its associated labelled set of documents,
              and compute the mean MRR over all topics (how far down the ranking the first relevant document appears)
        """
        frrs, f1s, fprs = [], [], []
        for topic_id, topic_text in tqdm(list(self.topicid2text.items())[:self.max_topics]):

            if topic_id not in self.rel_dict:
                # can't evaluate with no judgments
                continue

            doc_ids = list(self.rel_dict[topic_id].keys())
            logger.info(f"number of ranked docs: {len(doc_ids)}")
            doc_set = self.get_indexes_from_ids(doc_ids)

            # run IR pipeline on set of indexes corresponding to labelled doc_ids
            ranked_pairs = self.ctm.match_pipeline(topic_text, doc_set=doc_set)

            # get NCTIDs from ranking
            ranked_ids = [nct_id for nct_id, doc_text in ranked_pairs]

            # calculate metrics
            fpr, frr = calc_first_positive_rank(ranked_ids, self.rel_dict[topic_id])
            f1 = calc_f1(ranked_ids, self.rel_dict[topic_id])

            if self.sanity_check_ids is not None and (topic_id in self.sanity_check_ids):
                self.sanity_check(topic_id, topic_text, ranked_pairs, self.rel_dict[topic_id])

            fprs.append(fpr)
            frrs.append(frr)
            f1s.append(f1)

        mean_fpr = sum(fprs) / len(fprs)
        std_fpr = np.std(fprs)
        mean_frr = sum(frrs) / len(frrs)
        std_frr = np.std(frrs)
        mean_f1 = sum(f1s) / len(f1s)
        std_f1 = np.std(f1s)

        return {
            "mean_fpr": mean_fpr, "std_fpr": std_fpr,
            "mean_frr": mean_frr, "std_frr": std_frr,
            "mean_f1": mean_f1, "std_f1": std_f1
        }

    def get_indexes_from_ids(self, doc_id_set: List[str]) -> List[int]:
        """
        desc: map each document id in doc_id_set to its row index in the IR dataframes
        returns: list of indexes (ids with no matching row are skipped)
        """
        doc_indices = []
        for doc_id in doc_id_set:
            index_row = np.where(self.ctm.data.index2docid['text'] == doc_id)
            if len(index_row[0]) == 0:
                continue
            doc_indices.append(index_row[0][0])
        return doc_indices

    def sanity_check(self, topic_id, topic_text, ranked_pairs: List[Tuple[str, str]], rel_dict) -> None:
        logger.info(f"{topic_id=} {topic_text}")
        for doc_id, doc_text in ranked_pairs:
            rel_score = rel_dict[doc_id]
            logger.info(f"{rel_score} {doc_id} {doc_text}")
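A usage sketch; the qrels and topic paths below are placeholders, and the metric keys follow the dict returned by evaluate():

from ctmatch.evaluator import Evaluator, EvaluatorConfig

config = EvaluatorConfig(
    rel_paths=['/path/to/qrels.txt'],        # placeholder path
    trec_topic_path='/path/to/topics.xml',   # placeholder path
    max_topics=50,
    filters=['svm', 'classifier'],
)
metrics = Evaluator(config).evaluate()
print(metrics['mean_frr'], metrics['mean_f1'])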
ctmatch/match.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import logging
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
# external imports
|
7 |
+
from sentence_transformers import SentenceTransformer
|
8 |
+
from transformers import pipeline
|
9 |
+
from numpy.linalg import norm
|
10 |
+
from pathlib import Path
|
11 |
+
from sklearn import svm
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import json
|
15 |
+
|
16 |
+
|
17 |
+
# package tools
|
18 |
+
from .models.classifier_model import ClassifierModel
|
19 |
+
from .utils.ctmatch_utils import get_processed_data, exclusive_argmax
|
20 |
+
from .models.gen_model import GenModel
|
21 |
+
from .pipeconfig import PipeConfig
|
22 |
+
from .pipetopic import PipeTopic
|
23 |
+
from .dataprep import DataPrep
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
logger.setLevel(logging.INFO)
|
28 |
+
|
29 |
+
|
30 |
+
CT_CATEGORIES = [
|
31 |
+
"pulmonary", "cardiac", "gastrointestinal", "renal", "psychological", "genetic", "pediatric",
|
32 |
+
"neurological", "cancer", "reproductive", "endocrine", "infection", "healthy", "other"
|
33 |
+
]
|
34 |
+
|
35 |
+
|
36 |
+
GEN_INIT_PROMPT = "I will give you a patient description and a set of clinical trial documents. Each document will have a NCTID. I would like you to return the set of NCTIDs ranked from most to least relevant for patient in the description.\n"
|
37 |
+
|
38 |
+
|
39 |
+
class CTMatch:
|
40 |
+
|
41 |
+
def __init__(self, pipe_config: Optional[PipeConfig] = None) -> None:
|
42 |
+
# default to model config with full ir setup
|
43 |
+
self.pipe_config = pipe_config if pipe_config is not None else PipeConfig(ir_setup=True)
|
44 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
45 |
+
self.data = DataPrep(self.pipe_config)
|
46 |
+
self.classifier_model = ClassifierModel(self.pipe_config, self.data, self.device)
|
47 |
+
self.embedding_model = SentenceTransformer(self.pipe_config.embedding_model_checkpoint)
|
48 |
+
self.gen_model = GenModel(self.pipe_config)
|
49 |
+
self.category_model = None
|
50 |
+
self.filters: Optional[List[str]] = pipe_config.filters
|
51 |
+
|
52 |
+
# filter params
|
53 |
+
self.sim_top_n = 10000
|
54 |
+
self.svm_top_n = 100
|
55 |
+
self.classifier_top_n = 50
|
56 |
+
self.gen_top_n = 10
|
57 |
+
|
58 |
+
|
59 |
+
# main api method
|
60 |
+
def match_pipeline(self, topic: str, top_k: int = 10, doc_set: Optional[List[int]] = None) -> List[str]:
|
61 |
+
|
62 |
+
if doc_set is None:
|
63 |
+
# start off will all doc indexes
|
64 |
+
doc_set = [i for i in range(len(self.data.index2docid))]
|
65 |
+
else:
|
66 |
+
self.reset_filter_params(len(doc_set))
|
67 |
+
|
68 |
+
# get topic representations for pipeline filters
|
69 |
+
pipe_topic = self.get_pipe_topic(topic)
|
70 |
+
|
71 |
+
if self.filters is None or ('sim' in self.filters):
|
72 |
+
# first filter, category + embedding similarity
|
73 |
+
doc_set = self.sim_filter(pipe_topic, doc_set, top_n=self.sim_top_n)
|
74 |
+
|
75 |
+
if self.filters is None or ('svm' in self.filters):
|
76 |
+
# second filter, SVM
|
77 |
+
doc_set = self.svm_filter(pipe_topic, doc_set, top_n=self.svm_top_n)
|
78 |
+
|
79 |
+
if self.filters is None or ('classifier' in self.filters):
|
80 |
+
# third filter, classifier-LM (reranking)
|
81 |
+
doc_set = self.classifier_filter(pipe_topic, doc_set, top_n=self.classifier_top_n)
|
82 |
+
|
83 |
+
if self.filters is None or ('gen' in self.filters):
|
84 |
+
# fourth filter, generative-LM
|
85 |
+
doc_set = self.gen_filter(pipe_topic, doc_set, top_n=top_k)
|
86 |
+
|
87 |
+
return self.get_return_data(doc_set[:min(top_k, len(doc_set))])
|
88 |
+
|
89 |
+
|
90 |
+
def reset_filter_params(self, val: int) -> None:
|
91 |
+
self.sim_top_n = self.svm_top_n = self.classifier_top_n = self.gen_top_n = val
|
92 |
+
|
93 |
+
|
94 |
+
# ------------------------------------------------------------------------------------------ #
|
95 |
+
# filtering methods
|
96 |
+
# ------------------------------------------------------------------------------------------ #
|
97 |
+
|
98 |
+
def sim_filter(self, pipe_topic: PipeTopic, doc_set: List[int], top_n: int) -> List[int]:
|
99 |
+
"""
|
100 |
+
filter documents by similarity to topic
|
101 |
+
doing this with loop and cosine similarity instead of linear kernel because of memory issues
|
102 |
+
"""
|
103 |
+
logger.info(f"running sim filter on {len(doc_set)} docs")
|
104 |
+
|
105 |
+
topic_cat_vec = exclusive_argmax(pipe_topic.category_vec)
|
106 |
+
norm_topic_emb = norm(pipe_topic.embedding_vec)
|
107 |
+
cosine_dists = []
|
108 |
+
for doc_idx in doc_set:
|
109 |
+
doc_cat_vec = self.redist_other_category(self.data.doc_categories_df.iloc[doc_idx].values)
|
110 |
+
|
111 |
+
# only consider strongest predicted category
|
112 |
+
doc_cat_vec = exclusive_argmax(doc_cat_vec)
|
113 |
+
doc_emb_vec = self.data.doc_embeddings_df.iloc[doc_idx].values
|
114 |
+
|
115 |
+
topic_argmax = np.argmax(topic_cat_vec)
|
116 |
+
doc_argmax = np.argmax(doc_cat_vec)
|
117 |
+
cat_dist = 0. if (topic_argmax == doc_argmax) else 1.
|
118 |
+
emb_dist = np.dot(pipe_topic.embedding_vec, doc_emb_vec) / (norm_topic_emb * norm(doc_emb_vec))
|
119 |
+
combined_dist = cat_dist + emb_dist
|
120 |
+
cosine_dists.append(combined_dist)
|
121 |
+
|
122 |
+
sorted_indices = list(np.argsort(cosine_dists))[:min(len(doc_set), top_n)]
|
123 |
+
|
124 |
+
# return top n doc indices by combined similiarity, biggest to smallest
|
125 |
+
return [doc_set[i] for i in sorted_indices]
|
126 |
+
|
127 |
+
|
128 |
+
def svm_filter(self, topic: PipeTopic, doc_set: List[int], top_n: int) -> List[int]:
|
129 |
+
"""
|
130 |
+
filter documents by training an SVM on topic and doc embeddings
|
131 |
+
"""
|
132 |
+
logger.info(f"running svm filter on {len(doc_set)} documents")
|
133 |
+
|
134 |
+
# build training data and prediction vector of single positive class for SVM
|
135 |
+
topic_embedding_vec = topic.embedding_vec[np.newaxis, :]
|
136 |
+
x = np.concatenate([topic_embedding_vec, self.data.doc_embeddings_df.iloc[doc_set].values], axis=0)
|
137 |
+
y = np.zeros(len(doc_set) + 1)
|
138 |
+
y[0] = 1
|
139 |
+
|
140 |
+
# define and fit SVM
|
141 |
+
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
|
142 |
+
clf.fit(x, y)
|
143 |
+
|
144 |
+
# infer for similarities
|
145 |
+
similarities = clf.decision_function(x)
|
146 |
+
|
147 |
+
# get top n doc indices by similiarity, biggest to smallest
|
148 |
+
result = list(np.argsort(-similarities)[:min(len(doc_set) + 1, top_n + 1)])
|
149 |
+
|
150 |
+
# remove topic from result
|
151 |
+
result.remove(0)
|
152 |
+
|
153 |
+
# indexes got shifted by 1 because topic was included in doc_set
|
154 |
+
return [doc_set[(r - 1)] for r in result]
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
def classifier_filter(self, pipe_topic: PipeTopic, doc_set: List[int], top_n: int) -> List[int]:
|
159 |
+
"""
|
160 |
+
filter documents by classifier no relevance prediction
|
161 |
+
"""
|
162 |
+
logger.info(f"running classifier filter on {len(doc_set)} documents")
|
163 |
+
|
164 |
+
# get doc texts
|
165 |
+
doc_texts = [v[0] for v in self.data.doc_texts_df.iloc[doc_set].values]
|
166 |
+
|
167 |
+
# sort by reverse irrelevant prediction
|
168 |
+
neg_predictions = np.asarray([p[0] for p in self.classifier_model.batch_inference(pipe_topic.topic_text, doc_texts, return_preds=True)])
|
169 |
+
|
170 |
+
# return top n doc indices by classifier, biggest to smallest
|
171 |
+
sorted_indices = list(np.argsort(neg_predictions)[:min(len(doc_set), top_n)])
|
172 |
+
return [doc_set[i] for i in sorted_indices]
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
def gen_filter(self, topic: PipeTopic, doc_set: List[int], top_n: int = 10) -> List[int]:
|
177 |
+
"""
|
178 |
+
gen model supplies a ranking of remaming docs by evaluating the pairs of topic and doc texts
|
179 |
+
|
180 |
+
in order to overcome the context length limitation, we need to do a kind of left-binary search over multiple
|
181 |
+
prompts to arrive at a ranking that meets the number of documents requirement (top_n)
|
182 |
+
|
183 |
+
may take a few minutes to run through all queries and subqueries depending on size of doc_set
|
184 |
+
|
185 |
+
"""
|
186 |
+
logger.info(f"running gen filter on {len(doc_set)} documents")
|
187 |
+
|
188 |
+
assert top_n > 0, "top_n must be greater than 0"
|
189 |
+
|
190 |
+
ranked_docs = doc_set
|
191 |
+
iters = 0
|
192 |
+
while (len(ranked_docs) > top_n) and (iters < 10) and (len(ranked_docs) // 2 > top_n):
|
193 |
+
query_prompts = self.get_subqueries(topic, ranked_docs)
|
194 |
+
|
195 |
+
logger.info(f"calling gen model on {len(query_prompts)} subqueries")
|
196 |
+
|
197 |
+
# get gen model response for each query_prompt
|
198 |
+
subrankings = []
|
199 |
+
for prompt in query_prompts:
|
200 |
+
subrank = self.gen_model.gen_response(prompt)
|
201 |
+
|
202 |
+
# keep the top half of each subranking
|
203 |
+
subrankings.extend(subrank[:len(subrank) // 2])
|
204 |
+
|
205 |
+
ranked_docs = subrankings
|
206 |
+
iters += 1
|
207 |
+
|
208 |
+
return ranked_docs[:min(len(ranked_docs), top_n)]
|
209 |
+
|
210 |
+
# ------------------------------------------------------------------------------------------ #
|
211 |
+
# filter helper methods
|
212 |
+
# ------------------------------------------------------------------------------------------ #
|
213 |
+
|
214 |
+
def get_pipe_topic(self, topic):
|
215 |
+
pipe_topic = PipeTopic(
|
216 |
+
topic_text=topic,
|
217 |
+
embedding_vec=self.get_embeddings([topic])[0], # 1 x embedding_dim (default=384)
|
218 |
+
category_vec=self.get_categories(topic) # 1 x 14
|
219 |
+
)
|
220 |
+
return pipe_topic
|
221 |
+
|
222 |
+
|
223 |
+
def get_embeddings(self, texts: List[str]) -> List[float]:
|
224 |
+
return self.embedding_model.encode(texts)
|
225 |
+
|
226 |
+
def get_categories(self, text: str) -> str:
|
227 |
+
if self.category_model is None:
|
228 |
+
self.category_model = pipeline(
|
229 |
+
'zero-shot-classification',
|
230 |
+
model=self.pipe_config.category_model_checkpoint,
|
231 |
+
device=0
|
232 |
+
        )
        output = self.category_model(text, candidate_labels=CT_CATEGORIES)
        score_dict = {output['labels'][i]: output['scores'][i] for i in range(len(output['labels']))}

        # to be consistent with doc category vecs
        sorted_keys = sorted(score_dict.keys())
        return self.redist_other_category(np.array([score_dict[k] for k in sorted_keys]))

    def redist_other_category(self, category_vec: np.ndarray, other_dim: int = 8) -> np.ndarray:
        """
        redistribute 'other' category weight to all other categories
        """
        other_wt = category_vec[other_dim]
        other_wt_dist = other_wt / (len(category_vec) - 1)
        redist_cat_vec = category_vec + other_wt_dist
        redist_cat_vec[other_dim] = 0
        return redist_cat_vec

    def get_gen_query_prompt(self, topic: PipeTopic, doc_set: List[int]) -> Tuple[str, int]:
        query_prompt = f"{GEN_INIT_PROMPT}Patient description: {topic.topic_text}\n"

        for i, doc_text in enumerate(self.data.doc_texts_df.iloc[doc_set].values):
            query_prompt += f"NCTID: {doc_set[i]}, "
            query_prompt += f"Eligibility Criteria: {doc_text[0]}\n"

            # not really token length because it isn't tokenized yet, but close enough if we undershoot
            prompt_len = len(query_prompt.split())
            if prompt_len > self.pipe_config.max_query_length:
                break

        return query_prompt, i

    def get_subqueries(self, topic: PipeTopic, doc_set: List[int]) -> List[str]:
        query_prompts = []
        i = 0
        while i < len(doc_set) - 1:
            # break the querying over the remaining doc set into multiple prompts
            query_prompt, used_i = self.get_gen_query_prompt(topic, doc_set[i:])
            query_prompts.append(query_prompt)
            i += used_i

        return query_prompts

    def get_return_data(self, doc_set: List[int]) -> List[Tuple[str, str]]:
        return_data = []
        for idx in doc_set:
            nctid = self.data.index2docid.iloc[idx].values[0]
            return_data.append((nctid, self.data.doc_texts_df.iloc[idx].values[0]))
        return return_data

    # ------------------------------------------------------------------------------------------ #
    # data prep methods that rely on the model in the CTMatch object (not run during routine use)
    # ------------------------------------------------------------------------------------------ #

    def prep_ir_text(self, doc: Dict[str, List[str]], max_len: int = 512) -> str:
        inc_text = ' '.join(doc['elig_crit']['include_criteria'])
        exc_text = ' '.join(doc['elig_crit']['exclude_criteria'])
        all_text = f"Inclusion Criteria: {inc_text}, Exclusion Criteria: {exc_text}"
        split_text = all_text.split()
        return ' '.join(split_text[:min(max_len, len(split_text))])

    def prep_and_save_ir_dataset(self):
        category_data = self.data.get_category_data()
        with open(self.pipe_config.ir_save_path, 'w') as wf:
            for ir_data in self.prep_ir_data():
                ir_data['categories'] = str(category_data[ir_data['id']])
                wf.write(json.dumps(ir_data))
                wf.write('\n')

    def prep_ir_data(self):
        for data_path in self.pipe_config.processed_data_paths:
            for i, doc in enumerate(get_processed_data(data_path)):
                if i % 10000 == 0:
                    logger.info(f"Prepping doc {i}")

                ir_data_entry = dict()
                ir_data_entry['id'] = doc['id']
                doc_text = self.prep_ir_text(doc)
                ir_data_entry['doc_text'] = doc_text
                yield ir_data_entry

    def save_texts(self) -> Dict[int, str]:
        idx2id = dict()
        with open(Path(self.pipe_config.ir_save_path).parent / 'texts', 'w', encoding='utf-8') as wf:
            for i, doc in enumerate(get_processed_data(self.pipe_config.ir_save_path)):
                idx2id[i] = doc['id']
                if i % 10000 == 0:
                    logger.info(f"Prepping doc {i}")

                wf.write(doc['doc_text'])
                wf.write('\n')
        return idx2id
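For intuition, `redist_other_category` exists because the zero-shot classifier almost always puts some mass on the catch-all "other" label; the method spreads that mass evenly over the remaining categories before vectors are compared, and `other_dim=8` is where "other" lands in the sorted CT_CATEGORIES list used elsewhere in the repo. A minimal standalone sketch of the arithmetic on a toy 4-category vector (index 2 playing the role of "other" here):

import numpy as np

def redist_other_category(category_vec: np.ndarray, other_dim: int) -> np.ndarray:
    # spread the 'other' weight evenly over the remaining categories, then zero it out
    other_wt_dist = category_vec[other_dim] / (len(category_vec) - 1)
    redist = category_vec + other_wt_dist
    redist[other_dim] = 0
    return redist

vec = np.array([0.4, 0.2, 0.3, 0.1])            # toy zero-shot scores
print(redist_other_category(vec, other_dim=2))  # [0.5 0.3 0.  0.2], still sums to 1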
ctmatch/models/classifier_model.py
ADDED
@@ -0,0 +1,396 @@
import logging
from pathlib import Path
from tqdm.auto import tqdm
from typing import List, Tuple

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, get_scheduler
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.onnxruntime.configuration import OptimizationConfig
from optimum.onnxruntime import ORTOptimizer
import evaluate

from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch import nn
import torch

from nn_pruning.patch_coordinator import ModelPatchingCoordinator, SparseTrainingArguments
from nn_pruning.inference_model_patcher import optimize_model
from nn_pruning.sparse_trainer import SparseTrainer

from ..pipeconfig import PipeConfig
from ..dataprep import DataPrep

logger = logging.getLogger(__name__)

PRUNED_HUB_MODEL_NAME = 'semaj83/scibert_finetuned_pruned_ctmatch'


class WeightedLossTrainer(Trainer):
    def __init__(self, label_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_weights = label_weights

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.get("logits")
        labels = inputs.get("labels")
        loss_func = nn.CrossEntropyLoss(weight=self.label_weights)
        loss = loss_func(logits, labels)
        return (loss, outputs) if return_outputs else loss


class PruningTrainer(SparseTrainer, WeightedLossTrainer):
    def __init__(self, sparse_args, *args, **kwargs):
        WeightedLossTrainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)


class ClassifierModel:

    def __init__(self, model_config: PipeConfig, data: DataPrep, device: str):
        self.model_config = model_config
        self.dataset = data.ct_dataset
        self.tokenizer = data.classifier_tokenizer
        self.tokenize_func = data.tokenize_function
        self.trainer = None
        self.optimizer = None
        self.lr_scheduler = None
        self.device = device

        if not self.model_config.ir_setup:
            self.train_dataset_df = data.ct_dataset['train'].to_pandas()
            self.num_training_steps = self.model_config.train_epochs * len(self.dataset['train'])

        self.model = self.load_model()
        self.pruned_model = None

        if not self.model_config.use_trainer and not self.model_config.ir_setup:
            self.train_dataloader, self.val_dataloader = self.get_dataloaders()

        if self.model_config.prune:
            self.prune_trainer = None
            self.sparse_args = self.get_sparse_args()
            self.mpc = self.get_model_patching_coordinator()

    # ------------------ model loading ------------------ #
    def get_model(self):
        if self.model_config.num_classes == 0:
            return AutoModelForSequenceClassification.from_pretrained(self.model_config.classifier_model_checkpoint)

        id2label, label2id = self.get_label_mapping()
        model = AutoModelForSequenceClassification.from_pretrained(
            self.model_config.classifier_model_checkpoint,
            num_labels=self.model_config.num_classes,  # replaces the final head with a linear layer of num_labels outputs (fine-tuning)
            id2label=id2label, label2id=label2id,
            ignore_mismatched_sizes=True  # because of pruned model changes
        )

        if 'pruned' in self.model_config.classifier_model_checkpoint:
            model = optimize_model(model, "dense")

        return self.add_pad_token(model)

    def add_pad_token(self, model):
        if model.config.pad_token_id is None:
            model.config.pad_token_id = model.config.eos_token_id
        return model

    def load_model(self):
        self.model = self.get_model()

        if self.model_config.ir_setup:
            return self.model

        self.optimizer = AdamW(self.model.parameters(), lr=self.model_config.learning_rate, weight_decay=self.model_config.weight_decay)
        self.num_training_steps = self.model_config.train_epochs * len(self.dataset['train'])
        self.lr_scheduler = get_scheduler(
            name="linear",
            optimizer=self.optimizer,
            num_warmup_steps=self.model_config.warmup_steps,
            num_training_steps=self.num_training_steps
        )

        if self.model_config.use_trainer and not self.model_config.prune:
            self.trainer = self.get_trainer()
        else:
            self.model = self.model.to(self.device)

        return self.model

    def get_label_mapping(self):
        #id2label = {idx:self.dataset['train'].features["labels"].int2str(idx) for idx in range(3)}
        id2label = {0: 'not_relevant', 1: 'partially_relevant', 2: 'relevant'}
        label2id = {v: k for k, v in id2label.items()}
        return id2label, label2id

    def get_label_weights(self):
        label_weights = (1 - (self.train_dataset_df["labels"].value_counts().sort_index() / len(self.train_dataset_df))).values
        return torch.from_numpy(label_weights).float().to(self.device)

    def get_trainer(self):
        return WeightedLossTrainer(
            model=self.model,
            optimizers=(self.optimizer, self.lr_scheduler),
            args=self.get_training_args_obj(),
            compute_metrics=self.compute_metrics,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["validation"],
            tokenizer=self.tokenizer,
            label_weights=self.get_label_weights()
        )

    def get_training_args_obj(self):
        output_dir = self.model_config.output_dir if self.model_config.output_dir is not None else self.model_config.classifier_data_path.parent.parent.as_posix()

        return TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=self.model_config.train_epochs,
            learning_rate=self.model_config.learning_rate,
            per_device_train_batch_size=self.model_config.batch_size,
            per_device_eval_batch_size=self.model_config.batch_size,
            weight_decay=self.model_config.weight_decay,
            evaluation_strategy="epoch",
            logging_steps=len(self.dataset["train"]) // self.model_config.batch_size,
            fp16=self.model_config.fp16
        )

    def train_and_predict(self):
        if self.trainer is not None:
            self.trainer.train()
            predictions = self.trainer.predict(self.dataset["test"])
            logger.info(predictions.metrics.items())
        else:
            self.loss_func = nn.CrossEntropyLoss(weight=self.get_label_weights())
            self.manual_train()
            self.manual_eval()

    # ------------------ native torch training loop ------------------ #
    def get_dataloaders(self) -> Tuple[DataLoader, DataLoader]:
        train_dataloader = DataLoader(self.dataset['train'], shuffle=True, batch_size=self.model_config.batch_size)
        val_dataloader = DataLoader(self.dataset['validation'], batch_size=self.model_config.batch_size)
        return train_dataloader, val_dataloader

    # taken from ctmatch for messing about
    def manual_train(self):
        progress_bar = tqdm(range(self.num_training_steps))
        self.model.train()
        for epoch in range(self.model_config.train_epochs):
            for batch in tqdm(self.train_dataloader):
                batch = {k: v.to(self.model.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                loss = self.loss_func(outputs.logits, batch['labels'])
                #total_loss += loss.item()
                loss.backward()

                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

            self.manual_eval()
            logger.info(f"{loss=}")
            progress_bar.update(1)

    def manual_eval(self):
        metric = evaluate.load("f1")
        self.model.eval()
        for batch in self.val_dataloader:
            batch = {k: v.to(self.model.device) for k, v in batch.items()}

            # don't learn during evaluation
            with torch.no_grad():
                outputs = self.model(**batch)

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        logger.info(metric.compute(average='weighted'))

    def get_sklearn_metrics(self):
        with torch.no_grad():
            if self.model_config.use_trainer:
                if self.model_config.prune:
                    self.prune_trainer.model.to(self.device)
                    logger.info("using pruned trainer model")
                    preds = self.prune_trainer.predict(self.dataset['test']).predictions
                else:
                    preds = self.trainer.predict(self.dataset['test']).predictions

                if "bart" in self.model_config.name:
                    preds = preds[0]

                y_preds = list(preds.argmax(axis=1))
            else:
                if self.model_config.prune:
                    model = self.pruned_model.to(self.device)
                else:
                    model = self.model.to(self.device)
                y_preds = []
                for input_ids in self.dataset['test']['input_ids']:
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(self.device)
                    y_pred = model(input_ids).logits.argmax().item()
                    y_preds.append(y_pred)

        y_trues = list(self.dataset['test']['labels'])
        return confusion_matrix(y_trues, y_preds), classification_report(y_trues, y_preds)

    def compute_metrics(self, pred):
        labels = pred.label_ids
        preds = pred.predictions
        if "bart" in self.model_config.name:
            preds = preds[0]

        preds = preds.argmax(-1)
        f1 = f1_score(labels, preds, average="weighted")
        return {"f1": f1}

    def inference_single_example(self, topic: str, doc: str, return_preds: bool = False) -> str:
        """
        desc: method to predict relevance label on new topic, doc examples
        """
        ex = {'doc': doc, 'topic': topic}
        with torch.no_grad():
            inputs = torch.LongTensor(self.tokenize_func(ex)['input_ids']).unsqueeze(0)
            outputs = self.model(inputs).logits
            if return_preds:
                return torch.nn.functional.softmax(outputs, dim=1).squeeze(0)
            return str(outputs.argmax().item())

    def batch_inference(self, topic: str, docs: List[str], return_preds: bool = False) -> List[str]:
        topic_repeats = [topic for _ in range(len(docs))]
        inputs = self.tokenizer(
            topic_repeats, docs, return_tensors='pt',
            truncation=self.model_config.truncation,
            padding=self.model_config.padding,
            max_length=self.model_config.max_length
        )

        with torch.no_grad():
            outputs = torch.nn.functional.softmax(self.model(**inputs).logits, dim=1)

        if return_preds:
            return outputs

        return outputs.argmax(dim=1).tolist()

    # ------------------ pruning ------------------ #

    def prune_model(self):
        self.mpc.patch_model(self.model)
        self.model.save_pretrained("models/patched")
        self.prune_trainer = self.get_pruning_trainer()
        self.prune_trainer.set_patch_coordinator(self.mpc)
        self.prune_trainer.train()
        self.mpc.compile_model(self.prune_trainer.model)
        if self.model_config.push_to_hub:
            # can't save the optimized model to the hub
            self.prune_trainer.model.push_to_hub(PRUNED_HUB_MODEL_NAME)

        self.pruned_model = optimize_model(self.prune_trainer.model, "dense")

    def get_sparse_args(self):
        sparse_args = SparseTrainingArguments()

        hyperparams = {
            "dense_pruning_method": "topK:1d_alt",
            "attention_pruning_method": "topK",
            "initial_threshold": 1.0,
            "final_threshold": 0.5,
            "initial_warmup": 1,
            "final_warmup": 3,
            "attention_block_rows": 32,
            "attention_block_cols": 32,
            "attention_output_with_dense": 0
        }

        for k, v in hyperparams.items():
            if hasattr(sparse_args, k):
                setattr(sparse_args, k, v)
            else:
                print(f"sparse_args does not have argument {k}")

        return sparse_args

    def get_pruning_trainer(self):
        return PruningTrainer(
            sparse_args=self.sparse_args,
            args=self.get_training_args_obj(),
            model=self.model,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["validation"],
            tokenizer=self.tokenizer,
            compute_metrics=self.compute_metrics,
            label_weights=self.get_label_weights()
        )

    def get_model_patching_coordinator(self):
        return ModelPatchingCoordinator(
            sparse_args=self.sparse_args,
            device=self.device,
            cache_dir="checkpoints",
            logit_names="logits",
            teacher_constructor=None
        )

    # onnx optimization
    def optimize_model(self):
        onnx_path = Path("onnx")
        model_id = self.model_config.classifier_model_checkpoint
        #assert self.pruned_model is not None, "pruned model must be loaded before optimizing"
        opt_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
        optimizer = ORTOptimizer.from_pretrained(opt_model)
        optimization_config = OptimizationConfig(optimization_level=99)  # enable all optimizations
        optimizer.optimize(
            save_dir=onnx_path,
            optimization_config=optimization_config,
        )
        opt_model.save_pretrained(onnx_path)
        self.tokenizer.save_pretrained(onnx_path)

        #optimized_model = ORTModelForSequenceClassification.from_pretrained(onnx_path, file_name="model_optimized.onnx")

        return opt_model
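For intuition on the weighted loss: `get_label_weights` gives each relevance class the weight 1 minus its relative frequency in the training split, so the rarer 'partially_relevant' and 'relevant' labels count more in the cross-entropy used by WeightedLossTrainer. A standalone sketch of the same computation on hypothetical class counts:

import pandas as pd
import torch

train_df = pd.DataFrame({"labels": [0] * 700 + [1] * 200 + [2] * 100})  # hypothetical imbalanced counts
label_weights = (1 - (train_df["labels"].value_counts().sort_index() / len(train_df))).values
print(torch.from_numpy(label_weights).float())  # tensor([0.3000, 0.8000, 0.9000])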
ctmatch/models/gen_model.py
ADDED
@@ -0,0 +1,83 @@
from typing import List, Optional

from ..pipeconfig import PipeConfig
import openai
import re


class GenModel:
    def __init__(self, pipe_config: PipeConfig) -> None:
        openai.api_key = pipe_config.openai_api_key
        self.pipe_config = pipe_config

    def gen_response(self, query_prompt: str, doc_set: Optional[List[int]] = None) -> List[int]:
        """
        uses an openai model to return a ranking of ids
        """
        if self.pipe_config.gen_model_checkpoint == 'text-davinci-003':
            response = openai.Completion.create(
                model=self.pipe_config.gen_model_checkpoint,
                prompt=query_prompt,
                temperature=0,
                max_tokens=200,
                top_p=1,
                frequency_penalty=0.0,
                presence_penalty=0.0
            )
        else:
            assert doc_set is not None, "doc_set must be provided for gpt-3.5-turbo"

            # for gpt-3.5-turbo
            response = openai.ChatCompletion.create(
                model=self.pipe_config.gen_model_checkpoint,
                messages=[{'role': 'user', 'content': query_prompt}],
                temperature=0.4,
                max_tokens=200,
                top_p=1,
                frequency_penalty=0.2,
                presence_penalty=0.0
            )

        if self.pipe_config.gen_model_checkpoint == 'text-davinci-003':
            return self.post_process_chatgpt_response(response)
        return self.post_process_gptturbo_response(response, doc_set=doc_set)

    def post_process_chatgpt_response(self, response):
        """
        the completion text could be:
            NCTID 6, NCTID 7, NCTID 5
            NCTID: 6, 7, 5
            6, 7, 5
            '1. 195155\n2. 186848\n3. 194407'
        """
        response_pattern = r"(?:NCTID\:?\s*)? ?(\d+)(?!\.)"
        text = response['choices'][0]['text']
        return [int(s) for s in re.findall(response_pattern, text)]

    def post_process_gptturbo_response(self, response, doc_set: List[int]):
        """
        the message content could be:
            'The most relevant clinical trial for this patient is ID 2, followed by ID 3. The remaining trials are not relevant for this patient's condition.'
        """
        text = response['choices'][0]['message']['content']
        ranking = []
        for substr in text.split():
            if substr.isdigit():
                ranking.append(int(substr))

        # the rest are arbitrarily ranked
        for ncid in doc_set:
            if ncid not in ranking:
                ranking.append(ncid)
        return ranking
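The ranking extraction hinges on `response_pattern`; a quick check of it against the response shapes listed in the docstring (a sketch, expected outputs shown as comments):

import re

response_pattern = r"(?:NCTID\:?\s*)? ?(\d+)(?!\.)"
print(re.findall(response_pattern, "NCTID 6, NCTID 7, NCTID 5"))        # ['6', '7', '5']
print(re.findall(response_pattern, "1. 195155\n2. 186848\n3. 194407"))  # ['195155', '186848', '194407']

The negative lookahead is what drops the single-digit list numerals ("1.", "2.", ...) while keeping the ids themselves.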
ctmatch/pipeconfig.py
ADDED
@@ -0,0 +1,42 @@
from typing import Dict, List, NamedTuple, Optional
from pathlib import Path


class PipeConfig(NamedTuple):
    name: str = 'scibert_finetuned_ctmatch'
    classifier_model_checkpoint: str = 'semaj83/scibert_finetuned_ctmatch'
    max_length: int = 512
    padding: bool = True
    truncation: bool = True
    batch_size: int = 16
    learning_rate: float = 2e-5
    train_epochs: int = 3
    weight_decay: float = 0.01
    warmup_steps: int = 500
    seed: int = 42
    splits: Dict[str, float] = {"train": 0.8, "val": 0.1}
    classifier_data_path: Path = Path("combined_classifier_data.jsonl")
    output_dir: Optional[str] = None
    convert_snli: bool = False
    use_trainer: bool = False
    num_classes: int = 3
    fp16: bool = False
    early_stopping: bool = False
    push_to_hub: bool = False
    ir_save_path: Optional[str] = None
    category_path: Optional[str] = None
    processed_data_paths: Optional[List[str]] = None
    max_query_length: int = 1200
    category_model_checkpoint: str = "facebook/bart-large-mnli"
    embedding_model_checkpoint: str = "sentence-transformers/all-MiniLM-L6-v2"
    gen_model_checkpoint: str = 'text-davinci-003'
    max_gen: int = 100
    openai_api_key: Optional[str] = None
    ir_setup: bool = False               # if true, use the IR pipeline setup: no classifier training or dataprep
    filters: Optional[List[str]] = None  # if provided, only apply these filters in the IR pipeline; options are {'sim', 'svm', 'classifier', 'gen'}
    prune: bool = False                  # if true, creates a pruned classifier model
    optimize: bool = False               # if true, creates an optimized classifier model
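Since PipeConfig is a NamedTuple, a run is configured once with keyword overrides and is immutable afterwards; `_replace` returns a modified copy. A hypothetical configuration for the classifier training-plus-pruning path (field values here are illustrative, not tuned recommendations):

train_config = PipeConfig(
    use_trainer=True,
    prune=True,                        # also produce a pruned classifier
    train_epochs=4,
    output_dir="runs/scibert_pruned",  # hypothetical output path
)
ir_config = train_config._replace(ir_setup=True, prune=False)  # reuse for an inference-only IR setup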
ctmatch/pipetopic.py
ADDED
@@ -0,0 +1,10 @@
from typing import Any, NamedTuple

class PipeTopic(NamedTuple):
    topic_text: str
    embedding_vec: Any
    category_vec: Any
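A PipeTopic just bundles one patient description with its two precomputed vectors. A toy construction (the real vectors come from the embedding and category models; all-MiniLM-L6-v2 produces 384-dimensional embeddings, and there is one category weight per entry of the 14-label CT category list):

import numpy as np

topic = PipeTopic(
    topic_text="62-year-old woman with poorly controlled type 2 diabetes",  # hypothetical topic
    embedding_vec=np.zeros(384),
    category_vec=np.zeros(14),
)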
ctmatch/scripts/build_combined_data.py
ADDED
@@ -0,0 +1,76 @@
from typing import Dict, List, Tuple
import json


COMBINED_CAT_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/combined_categories.jsonl'
CAT_SAVE_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/doc_categories.txt'
INDEX2DOCID_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/index2docid.txt'
INDEX2TOPICID_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/index2topicid.txt'

def load_category_dict(cat_path=COMBINED_CAT_PATH) -> Tuple[List, Dict[str, List[float]]]:
    """
    desc: gets category dict from category path
    """
    sorted_cat_keys = None

    with open(cat_path, 'r') as json_file:
        json_list = list(json_file)

    all_cat_dict = {}
    for s in json_list:
        s_data = json.loads(s)
        nct_id, cat_dict = s_data.popitem()

        if sorted_cat_keys is None:
            sorted_cat_keys = sorted(cat_dict.keys())

        all_cat_dict[nct_id] = [cat_dict[k] for k in sorted_cat_keys]

    return sorted_cat_keys, all_cat_dict

def load_index2id(index2id_path: str = INDEX2DOCID_PATH) -> Dict[str, str]:
    """
    desc: loads the index -> nct_id mapping from a csv path
    """
    index2id = {}
    with open(index2id_path, 'r') as f:
        for line in f:
            if len(line) < 2:
                continue
            idx, nct_id = line.split(',')
            index2id[idx] = nct_id.strip(' \n')

    return index2id

def build_cat_csv(save_path: str = CAT_SAVE_PATH) -> None:
    """
    desc: builds csv file for category data
    VERY important that the indexes (order) match the order of the embeddings (for nctid lookup in idx2id)
    """
    sorted_cat_keys, cat_dict = load_category_dict()
    idx2id = load_index2id()

    with open(save_path, 'w') as f:
        f.write(','.join(sorted_cat_keys))
        f.write('\n')
        for _, nct_id in idx2id.items():
            cat_vec = cat_dict[nct_id]
            cat_vec_str = ','.join([str(c) for c in cat_vec])
            f.write(cat_vec_str)
            f.write('\n')

if __name__ == '__main__':
    build_cat_csv()
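For reference, `load_index2id` expects one `index,nct_id` pair per line, and `build_cat_csv` relies on that ordering matching the embedding order. A hypothetical `index2docid.txt` (ids made up):

0,NCT00000001
1,NCT00000002
2,NCT00000003

which parses to {'0': 'NCT00000001', '1': 'NCT00000002', '2': 'NCT00000003'}.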
ctmatch/scripts/gen_categories.py
ADDED
@@ -0,0 +1,92 @@
from typing import Dict, Generator, List, Optional, Tuple
from ctmatch.utils.ctmatch_utils import get_processed_data
from ctmatch.ct_data_paths import get_data_tuples
from transformers import pipeline
import numpy as np
import json

CAT_GEN_MODEL = "facebook/bart-large-mnli"
#CAT_GEN_MODEL = "microsoft/biogpt"

CT_CATEGORIES = [
    "pulmonary", "cardiac", "gastrointestinal", "renal", "psychological", "genetic", "pediatric",
    "neurological", "cancer", "reproductive", "endocrine", "infection", "healthy", "other"
]


# --------------------------------------------------------------------------------------------------------------- #
# this script applies zero-shot classification labels from 'facebook/bart-large-mnli' to the
# documents of the dataset, including test, because it is realistic to pre-compute these labels:
# you have the documents apriori
# --------------------------------------------------------------------------------------------------------------- #
GET_ONLY = None


def stream_condition_data(data_chunk, doc_or_topic: str = 'doc') -> Generator[str, None, None]:
    for d in data_chunk:
        if doc_or_topic == 'topic':
            yield d['raw_text']
        else:
            condition = d['condition']
            if len(condition) == 0:
                yield 'no information'
            else:
                yield ' '.join(condition).lower()

def add_condition_category_labels(
    trec_or_kz: str = 'trec',
    model_checkpoint=CAT_GEN_MODEL,
    start: int = 0,
    doc_tuples: Optional[List[Tuple[str, str]]] = None,
    category_label='category',
    doc_or_topic: str = 'doc'
) -> None:
    pipe = pipeline(model=model_checkpoint, device=0)
    chunk_size = 1000

    # open the processed documents and add the category labels
    if doc_tuples is None:
        doc_tuples, _ = get_data_tuples(trec_or_kz=trec_or_kz)

    for _, target in doc_tuples:
        print(f"reading and writing to: {target}")
        data = [d for d in get_processed_data(target, get_only=GET_ONLY)]
        print(f"got {len(data)} records from {target}...")

        # overwrite with new records having the inferred category feature
        with open('/content/drive/MyDrive/ct_data23/processed_trec_topic_X.jsonl', 'w') as f:
            i = start
            print(f'starting at: {i}')
            while i < len(data):
                next_chunk_end = min(len(data), i + chunk_size)
                conditions = stream_condition_data(data[i:next_chunk_end], doc_or_topic=doc_or_topic)
                categories = gen_categories(pipe, conditions)
                print(f"generated {len(categories)} categories for {next_chunk_end - i} conditions...")
                for j in range(i, next_chunk_end):
                    data[j][category_label] = categories[j - i]
                    f.write(json.dumps(data[j]))
                    f.write('\n')

                if doc_or_topic == 'doc':
                    print(f"{i=}, doc condition: {data[i]['condition']}, generated category: {data[i]['category'].items()}")
                else:
                    print(f"{i=}, topic raw text condition: {data[i]['raw_text']}, generated category: {data[i]['category'].items()}")

                i += chunk_size

def gen_categories(pipe, text_dataset: Generator[str, None, None]) -> List[Dict[str, float]]:
    categories = []
    for output in pipe(text_dataset, candidate_labels=CT_CATEGORIES, batch_size=64):
        score_dict = {output['labels'][i]: output['scores'][i] for i in range(len(output['labels']))}
        #category = max(score_dict, key=score_dict.get)
        categories.append(score_dict)
    return categories

def gen_single_category_vector(pipe, text: str) -> np.ndarray:
    output = pipe(text, candidate_labels=CT_CATEGORIES)
    score_dict = {output['labels'][i]: output['scores'][i] for i in range(len(output['labels']))}
    return np.array(sorted(score_dict, key=score_dict.get, reverse=True))
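For reference, the zero-shot pipeline returns parallel `labels`/`scores` lists sorted by descending score, which is why the script re-keys them into a dict. A minimal sketch of a single call (the score values shown are illustrative):

from transformers import pipeline

pipe = pipeline("zero-shot-classification", model=CAT_GEN_MODEL)
output = pipe("chronic obstructive pulmonary disease", candidate_labels=CT_CATEGORIES)
# output["labels"] ~ ['pulmonary', 'cardiac', ...]   (sorted by score)
# output["scores"] ~ [0.91, 0.03, ...]               (illustrative values)
score_dict = dict(zip(output["labels"], output["scores"]))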
ctmatch/scripts/get_web_data.py
ADDED
@@ -0,0 +1,14 @@
from selenium import webdriver
from selenium.webdriver.common.by import By

def save_web_data(url: str) -> None:
    driver = webdriver.Chrome()
    driver.get(url)
    button = driver.find_element(By.CLASS_NAME, "save-list")
    button.click()


if __name__ == "__main__":
    url = "https://clinicaltrials.gov/ct2/results?cond=Heart+Diseases"
    save_web_data(url)
ctmatch/scripts/split_files.py
ADDED
@@ -0,0 +1,37 @@
from pathlib import Path
import argparse


parser = argparse.ArgumentParser()
parser.add_argument('folder',
                    help="supply a folder path to be split up; if it isn't a folder, the method won't do anything")

args = parser.parse_args()

MAX_FOLDER_SIZE = 2000


def split_files(folder: Path):

    assert folder.is_dir()
    num_dirs = 1
    curr_size = 0

    new_subfolder_path = folder.parent / f"{folder.as_posix()}_{num_dirs}"
    new_subfolder_path.mkdir(exist_ok=True)
    for file in folder.iterdir():
        if curr_size > MAX_FOLDER_SIZE:
            num_dirs += 1
            new_subfolder_path = folder.parent / f"{folder.as_posix()}_{num_dirs}"
            new_subfolder_path.mkdir(exist_ok=True)
            curr_size = 0
        else:
            curr_size += 1
        file.rename(new_subfolder_path / file.name)

if __name__ == "__main__":
    split_files(Path(args.folder))
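Usage sketch: running `python split_files.py docs` on a `docs/` folder holding 5000 files moves them into sibling folders `docs_1/`, `docs_2/`, `docs_3/` of roughly MAX_FOLDER_SIZE files each, leaving the original folder empty (the folder name and counts here are hypothetical).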
ctmatch/scripts/vis_script.py
ADDED
@@ -0,0 +1,334 @@
from typing import Dict, List, NamedTuple
from ctproc.proc import CTDocument, EligCrit
from matplotlib import pyplot as plt
from collections import defaultdict
from zipfile import ZipFile
from lxml import etree
import pandas
import re


from ctmatch.utils.ctmatch_utils import *

class FieldCounter(NamedTuple):
    missfld_counts: Dict[str, int] = defaultdict(int)
    emptfld_counts: Dict[str, int] = defaultdict(int)
    elig_form_counts: Dict[str, int] = defaultdict(int)
    unit_counts: Dict[str, int] = defaultdict(int)


#----------------------------------------------------------------#
# EDA Utility Functions
#----------------------------------------------------------------#

# viewing
def print_elig_result(doc, dont_print=[]):
    for k, v in doc.elig_crit.__dict__.items():
        if k in dont_print:
            continue
        if type(v) == list:
            print('\n' + k)
            for v_i in v:
                print(v_i)
        else:
            print(f"{k}: {v}")


def display_elig(docs: List[CTDocument]) -> None:
    age_in_elig_text_dist = count_elig_crit_age_in_text(docs)
    total = sum(age_in_elig_text_dist.values())
    print(f"{total} out of {len(docs)} documents had age in eligibility text ({total / len(docs):.2%})")

    age_in_elig_counts_df = pandas.DataFrame(age_in_elig_text_dist, index=[0])
    age_in_elig_counts_df.plot(kind="bar", xticks=[], xlabel="include_or_exclude", ylabel="count", title="Age in Eligibility Criteria Text Distribution")
    print(age_in_elig_counts_df)
    inc_ratio = age_in_elig_text_dist['include'] / total
    exc_ratio = age_in_elig_text_dist['exclude'] / total
    print(f"{age_in_elig_text_dist['include']} instances in inclusion statements ({inc_ratio:.2%}), {age_in_elig_text_dist['exclude']} instances in exclusion statements ({exc_ratio:.2%})")


def get_lengths(processed_docs: List[Dict[str, str]]) -> None:
    no_crit, miss_inc, miss_exc = 0, 0, 0
    inc_lens, exc_lens, all_lens = 0, 0, 0

    for i, d in enumerate(processed_docs):
        crit = d['elig_crit']['raw_text']
        inc_crit = d['elig_crit']['include_criteria']
        exc_crit = d['elig_crit']['exclude_criteria']

        if len(inc_crit) == 0:
            miss_inc += 1

        if len(exc_crit) == 0:
            miss_exc += 1

        if (len(exc_crit) == 0) and (len(inc_crit) == 0):
            no_crit += 1

        #print(crit)

        inc_length = sum([len(c.split()) for c in inc_crit])
        exc_length = sum([len(c.split()) for c in exc_crit])
        crit_len = inc_length + exc_length
        inc_lens += inc_length
        exc_lens += exc_length
        all_lens += crit_len

    print(f"{miss_inc=}, {miss_exc=}, {no_crit=}, {inc_lens / len(processed_docs)}, {exc_lens / len(processed_docs)}, {all_lens / len(processed_docs)}")


def print_ent_sent(ent_sent):
    for e in ent_sent:
        e_small = {}
        e_small['raw_text'] = e['raw_text']
        e_small['start'] = e['start']
        e_small['end'] = e['end']
        e_small['negation'] = e['negation']
        print(e_small.items())


#--------------------------------------------------------------------------------------#
# methods for getting counts
#--------------------------------------------------------------------------------------#

def process_counts(zip_data: str) -> FieldCounter:
    """
    desc: main method for counting fields over a zipped file of clinical trial XML documents from clinicaltrials.gov
    returns: a FieldCounter of missing-field, empty-field, eligibility-format, and age-unit counts
    """
    counts = FieldCounter()
    with ZipFile(zip_data, 'r') as zip_reader:
        for i, ct_file in enumerate(zip_reader.namelist()):
            if i % 1000 == 0:
                print(f"{i} docs processed")

            if not ct_file.endswith('xml'):
                continue

            counts = get_ct_file_counts(zip_reader.open(ct_file), counts)
    return counts


def get_ct_file_counts(xml_filereader, counts: FieldCounter) -> FieldCounter:
    doc_tree = etree.parse(xml_filereader)
    root = doc_tree.getroot()

    # adding new keys vs subdictionaries?????
    required_fields = {
        "id": None,
        "brief_title": None,
        "eligibility/criteria/textblock": None,
        "eligibility/gender": "Default Value",
        "eligibility/minimum_age": {"male": 0, "female": 0},
        "eligibility/maximum_age": {"male": 999., "female": 999.},
        "detailed_description/textblock": None,
        "condition": None,
        "condition/condition_browse": None,
        "intervention/intervention_type": None,
        "intervention/intervention_name": None,
        "intervention_browse/mesh_term": None,
        "brief_summary/textblock": None,
    }

    for field in required_fields.keys():
        field_tag = 'id_info/nct_id' if field == 'id' else field
        try:
            field_val = root.find(field_tag).text
            if not EMPTY_PATTERN.fullmatch(field_val):
                if field == 'eligibility/criteria/textblock':
                    get_elig_counts(field_val, counts.elig_form_counts)
                elif "age" in field:
                    age_match = AGE_PATTERN.match(field_val)
                    if age_match is not None:
                        unit = age_match.group('units')
                        if unit is not None:
                            counts.unit_counts[unit] += 1

        except:
            if root.find(field_tag) is None:
                counts.missfld_counts[field] += 1
            elif EMPTY_PATTERN.fullmatch(root.find(field_tag).text):
                counts.emptfld_counts[field] += 1

    return counts


def get_elig_counts(elig_text: str, elig_form_counts: Dict[str, int]) -> Dict[str, int]:
    assert elig_text is not None, "Eligibility text is empty"
    if re.search(r'[Ii]nclusion [Cc]riteria:[^\w]+\n', elig_text):
        if re.search(r'[Ee]xclusion [Cc]riteria:[^\w]+\n', elig_text):
            elig_form_counts["inc_and_exc"] += 1
            return elig_form_counts
        else:
            elig_form_counts["inc_only"] += 1
            return elig_form_counts

    elif re.search(r'[Ee]xclusion [Cc]riteria:[^\w]+\n', elig_text):
        elig_form_counts["exc_only"] += 1
        return elig_form_counts

    else:
        elig_form_counts["textblock"] += 1
        return elig_form_counts


def get_counts(docs: List[CTDocument]):
    gender_dist = defaultdict(int)
    min_age_dist = defaultdict(int)
    max_age_dist = defaultdict(int)
    for doc in docs:
        gender_dist[doc.elig_gender] += 1
        min_age_dist[doc.elig_crit.elig_min_age] += 1
        max_age_dist[doc.elig_max_age] += 1
    return gender_dist, min_age_dist, max_age_dist


def get_relled(topic_id, rel_dict):
    twos, ones, zeros = set(), set(), set()
    for doc_id, rel in rel_dict[topic_id].items():
        if rel == 1:
            ones.add(doc_id)
        elif rel == 2:
            twos.add(doc_id)
        else:
            zeros.add(doc_id)
    return {"twos": twos, "ones": ones, "zeros": zeros}

def scan_for_age(
    elig_crit: EligCrit,
    inc_or_ex: str = 'include'
) -> bool:
    crit_to_scan = elig_crit.include_criteria if inc_or_ex == 'include' else elig_crit.exclude_criteria
    for crit in crit_to_scan:
        if re.search(r' ages? ', crit.lower()) is not None:
            return True
    return False


def count_elig_crit_age_in_text(docs, skip_predefined: bool = True):
    age_in_elig_text_dist = defaultdict(int)
    skipped = 0
    for doc in docs:
        if skip_predefined:
            # author(s) have specified SOME criteria; assumes judgment prefers this field to free text in the criteria textblock
            if (doc.elig_min_age != 0) or (doc.elig_max_age != 999):
                skipped += 1
                continue

        age_in_elig_text_dist['include'] += scan_for_age(doc.elig_crit, 'include')
        age_in_elig_text_dist['exclude'] += scan_for_age(doc.elig_crit, 'exclude')

    print(f"Total skipped: {skipped}")
    return age_in_elig_text_dist


def get_missing_criteria(docs: List[CTDocument]):
    missing_inc_ids, missing_exc_ids = set(), set()
    for d in docs:

        if len(d.elig_crit.include_criteria) == 0:
            missing_inc_ids.add(d.nct_id)

        if len(d.elig_crit.exclude_criteria) == 0:
            missing_exc_ids.add(d.nct_id)

    return missing_inc_ids, missing_exc_ids


# for evaluating the effect of filtering
def get_doc_percent_elig(filtered_docs_by_topic: Dict[str, set]):
    percents_elig = []
    for topic_id, doc_list in filtered_docs_by_topic.items():
        per = len(doc_list) / 3262.0
        percents_elig.append(per)
        print(topic_id, len(doc_list), per)
    mean_elig = sum(percents_elig) / len(percents_elig)
    print(f"Mean eligible number of docs: {mean_elig}")


# plotting

def plot_counts(missfld_counts, emptfld_counts):
    miss_df = pandas.DataFrame(missfld_counts, index=[0])
    miss_df.plot(kind='bar', xticks=[], title="Missing Fields", ylabel="count", xlabel="field")
    plt.legend(loc=(1.04, 0))

    empt_df = pandas.DataFrame(emptfld_counts, index=[0])
    empt_df.plot(kind='bar', xticks=[], title="Empty Fields", ylabel="count", xlabel="field")
    plt.legend(loc=(1.04, 0))


#----------------------------------------------------------------#
# EDA Test Data Utility Functions
#----------------------------------------------------------------#

def get_test_rels(test_rels):
    rel_dict = defaultdict(lambda: defaultdict(int))
    rel_type_dict = defaultdict(int)
    for line in open(test_rels, 'r').readlines():
        topic_id, _, doc_id, rel = re.split(r'\s+', line.strip())
        rel_dict[topic_id][doc_id] = int(rel)
        rel_type_dict[rel] += 1
    return rel_dict, rel_type_dict

def analyze_test_rels(test_rels_path):
    rel_dict, rel_type_dict = get_test_rels(test_rels_path)

    print("Rel Type Results:")
    for t, n in rel_type_dict.items():
        print(t + ': ' + str(n))

    lengths = dict()
    all_qrelled_docs = set()
    for tid in rel_dict.keys():
        lengths[tid] = len(rel_dict[tid])
        for d in rel_dict[tid].keys():
            all_qrelled_docs.add(d)
    for topic, num_relled in lengths.items():
        print(topic, num_relled)
    print(f"Total relled: {len(all_qrelled_docs)}")
    return rel_type_dict, rel_dict, all_qrelled_docs


if __name__ == '__main__':
    qrels_path = '/Users/jameskelly/Documents/cp/ctmatch/data/qrels-clinical_trials.txt'
    rel_type_dict, rel_dict, all_qrelled_docs = analyze_test_rels(qrels_path)
    #docs_path = '/Users/jameskelly/Documents/cp/ctproc/clinicaltrials.gov-16_dec_2015_17.zip'
    #counts = process_counts(docs_path)
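The qrels files parsed by `get_test_rels` are whitespace-separated `topic_id iteration doc_id relevance` lines in the usual TREC style; a minimal sketch with made-up judgments:

import re
from collections import defaultdict

lines = ["20141 0 NCT00000001 2", "20141 0 NCT00000002 0"]  # hypothetical qrels lines
rel_dict = defaultdict(lambda: defaultdict(int))
for line in lines:
    topic_id, _, doc_id, rel = re.split(r'\s+', line.strip())
    rel_dict[topic_id][doc_id] = int(rel)
print(rel_dict["20141"]["NCT00000001"])  # 2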
ctmatch/utils/__init__.py
ADDED
File without changes
|
ctmatch/utils/ctmatch_utils.py
ADDED
@@ -0,0 +1,133 @@
from typing import Any, Dict, List, Optional, Set
from sklearn.metrics.pairwise import linear_kernel
from collections import defaultdict
from numpy.linalg import norm
from datasets import Dataset
import numpy as np
import json
import re


#----------------------------------------------------------------#
# global regex patterns for use throughout the methods
#----------------------------------------------------------------#

EMPTY_PATTERN = re.compile(r'[\n\s]+')
"""
both_inc_and_exc_pattern = re.compile(r\"\"\"[\s\n]*[Ii]nclusion [Cc]riteria:?  # top line of both
    (?:[ ]+[Ee]ligibility[ \w]+\:[ ])?                # could contain this unneeded bit next
    (?P<include_crit>[ \n\-\.\?\"\%\r\w\:\,\(\)]*)    # this should get all inclusion criteria as a string
    [Ee]xclusion[ ][Cc]riteria:?                      # delineator to exclusion criteria
    (?P<exclude_crit>[\w\W ]*)                        # exclusion criteria as string
    \"\"\", re.VERBOSE)
"""
INC_ONLY_PATTERN = re.compile(r'[\s\n]+[Ii]nclusion [Cc]riteria:?([\w\W ]*)')
EXC_ONLY_PATTERN = re.compile(r'[\n\r ]+[Ee]xclusion [Cc]riteria:?([\w\W ]*)')
AGE_PATTERN = re.compile(r'(?P<age>\d+) *(?P<units>\w+).*')
YEAR_PATTERN = re.compile(r'(?P<year>[yY]ears?.*)')
MONTH_PATTERN = re.compile(r'(?P<month>[mM]o(?:nth)?)')
WEEK_PATTERN = re.compile(r'(?P<week>[wW]eeks?)')

BOTH_INC_AND_EXC_PATTERN = re.compile(r"[\s\n]*[Ii]nclusion [Cc]riteria:?(?: +[Ee]ligibility[ \w]+\: )?(?P<include_crit>[ \n\-\.\?\"\%\r\w\:\,\(\)]*)[Ee]xclusion [Cc]riteria:?(?P<exclude_crit>[\w\W ]*)")


# -------------------------------------------------------------------------------------- #
# pretokenization utils (should be in a tokenizer...)
# -------------------------------------------------------------------------------------- #

def truncate(s: str, max_tokens: Optional[int] = None) -> str:
    if max_tokens is None:
        return s
    s_tokens = s.split()
    return ' '.join(s_tokens[:min(len(s_tokens), max_tokens)])


# -------------------------------------------------------------------------------------- #
# I/O utils
# -------------------------------------------------------------------------------------- #

def save_docs_jsonl(docs: List[Any], writefile: str) -> None:
    """
    desc: iteratively writes contents of docs as jsonl to writefile
    """
    with open(writefile, "w") as outfile:
        for doc in docs:
            json.dump(doc, outfile)
            outfile.write("\n")


def get_processed_data(proc_loc: str, get_only: Optional[Set[str]] = None):
    """
    proc_loc: str or path to location of docs in jsonl form
    """
    with open(proc_loc, 'r') as json_file:
        json_list = list(json_file)

    if get_only is None:
        for json_str in json_list:
            yield json.loads(json_str)
    else:
        for s in json_list:
            s_data = json.loads(s)
            if s_data["id"] in get_only:
                yield s_data
                get_only.remove(s_data['id'])
                if len(get_only) == 0:
                    return


def train_test_val_split(dataset, splits: Dict[str, float], seed: int = 37) -> Dataset:
    """
    splits a dataset having only "train" into one having train, test, val, with
    split sizes determined by splits["train"] and splits["val"] (dict must have those keys)
    """
    dataset = dataset["train"].train_test_split(train_size=splits["train"], seed=seed)
    train = dataset["train"]
    sub = train.train_test_split(test_size=splits["val"], seed=seed)
    new_train = sub["train"]
    new_val = sub["test"]
    dataset["train"] = new_train
    dataset["validation"] = new_val
    return dataset


#----------------------------------------------------------------#
# computation methods
#----------------------------------------------------------------#

def exclusive_argmax(vector: np.ndarray) -> np.ndarray:
    # return a one-hot vector with a 1 at the argmax position
    one_hot = np.zeros(len(vector))
    one_hot[np.argmax(vector)] = 1
    return one_hot


#----------------------------------------------------------------#
# evaluation methods (duplicated from ctproc scripts)
#----------------------------------------------------------------#

def get_test_rels(rel_path):
    rel_dict = defaultdict(lambda: defaultdict(int))
    rel_type_dict = defaultdict(int)
    for line in open(rel_path, 'r').readlines():
        topic_id, _, doc_id, rel = re.split(r'\s+', line.strip())
        rel_dict[topic_id][doc_id] = int(rel)
        rel_type_dict[rel] += 1
    return rel_dict, rel_type_dict
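Note that `train_test_val_split` takes splits['val'] as a fraction of the training split, not of the whole dataset. A small sketch on a toy Hugging Face dataset:

from datasets import Dataset, DatasetDict

ds = DatasetDict({"train": Dataset.from_dict({"x": list(range(100))})})
ds = train_test_val_split(ds, splits={"train": 0.8, "val": 0.1}, seed=37)
print({k: len(v) for k, v in ds.items()})  # {'train': 72, 'test': 20, 'validation': 8}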
ctmatch/utils/eval_utils.py
ADDED
@@ -0,0 +1,91 @@
from typing import Dict, List, Tuple

from sklearn.metrics import f1_score
from collections import defaultdict
from lxml import etree
import numpy as np


def get_trec_topic2text(topic_path) -> Dict[str, str]:
    """
    desc: main method for processing a single XML file of TREC21 patient descriptions, called "topics" in this setting
    returns: dict of topic_id: topic text
    """
    topic2text = {}
    topic_root = etree.parse(topic_path).getroot()
    for topic in topic_root:
        topic2text[topic.attrib['number']] = topic.text

    return topic2text


def get_kz_topic2text(topic_path) -> Dict[str, str]:
    """
    desc: main method for processing a single file of kz patient descriptions, called "topics" in this setting
    returns: dict of topic_id: topic text
    """
    topic2text = {}
    with open(topic_path, 'r') as f:
        for line in f.readlines():
            line = line.strip()

            if line.startswith('<TOP>'):
                topic_id, text = None, None
                continue

            if line.startswith('<NUM>'):
                topic_id = line[5:-6]

            elif line.startswith('<TITLE>'):
                text = line[7:].strip()
                topic2text[topic_id] = text

    return topic2text


def calc_first_positive_rank(ranked_ids: List[str], doc2rel: Dict[str, int], pos_val: int = 2) -> Tuple[int, float]:
    """
    desc: find the first document in the ranking with relevance == pos_val
    returns: that document's rank and its reciprocal rank (the per-topic contribution to MRR)
    """
    for i, doc_id in enumerate(ranked_ids):
        if doc2rel[doc_id] == pos_val:
            return i + 1, 1. / float(i + 1)
    return len(ranked_ids) + 1, 0.0


def calc_f1(ranked_ids: List[str], doc2rel: Dict[str, int]) -> float:
    label_counts = get_label_counts(doc2rel)
    predicted, ground_truth = [], []
    for doc_id in ranked_ids:
        # 2, 1, 0
        ground_truth.append(doc2rel[doc_id])
        pred_label = get_predicted_label(label_counts)
        predicted.append(pred_label)
        label_counts[pred_label] -= 1

    return f1_score(ground_truth, predicted, average='micro')


def get_label_counts(doc2rel: Dict[str, int]) -> Dict[int, int]:
    """
    returns a dict of label -> count, e.g. {2: <count_2s>, 1: <count_1s>, 0: <count_0s>}
    """
    label_counts = defaultdict(int)
    for scored_doc in doc2rel:
        label = doc2rel[scored_doc]
        label_counts[label] += 1
    return label_counts

def get_predicted_label(label_counts: Dict[int, int]) -> int:
    if label_counts[2] > 0:
        return 2
    if label_counts[1] > 0:
        return 1
    return 0
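`calc_first_positive_rank` returns both the rank of the first document judged fully relevant (pos_val=2) and its reciprocal; averaging the reciprocal over topics gives MRR. A worked toy case with made-up ids:

ranked_ids = ["NCT00000003", "NCT00000001", "NCT00000002"]        # hypothetical ranking
doc2rel = {"NCT00000001": 2, "NCT00000002": 1, "NCT00000003": 0}  # hypothetical judgments
print(calc_first_positive_rank(ranked_ids, doc2rel))              # (2, 0.5), first label-2 doc sits at rank 2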
requirements.txt
ADDED
@@ -0,0 +1,23 @@
#ctproc  (uncomment if doing data prep on raw ct documents)
#https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_md-0.4.0.tar.gz  (uncomment if using ctproc)
#pyserini==0.12.0  (uncomment if using ctproc with indexes, not recommended)
#git+https://github.com/semajyllek/transformers.git@add-biogpt-sequenceclassifier
#sacremoses  (uncomment if using biogpt)
sentence-transformers
huggingface_hub
scikit-learn
transformers
onnxruntime
nn_pruning
optimum
onnx

matplotlib
accelerate
datasets
evaluate
pandas
openai
lxml

gradio