"""
Code copied from AGXNet:
https://github.com/batmanlab/AGXNet
"""
import argparse
import json

import nltk
import pandas as pd
from tqdm import tqdm

# nltk.sent_tokenize / nltk.word_tokenize require the "punkt" tokenizer models;
# download them once with nltk.download("punkt") if they are not already installed.

parser = argparse.ArgumentParser(description="Itemize RadGraph Dataset.")
parser.add_argument(
    "--data-path",
    default="/PATH TO RADGRAPH DATA/RadGraph/physionet.org/files/radgraph/1.0.0/MIMIC-CXR_graphs.json",
    help="RadGraph data path.",
)
parser.add_argument(
    "--output-path",
    default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-itemized.csv",
    help="Output path for itemized RadGraph data.",
)
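
# Example invocation (illustrative only; the script filename and both paths are
# placeholders to be replaced with local locations, not values from this repository):
#
#   python radgraph_itemize.py \
#       --data-path /path/to/radgraph/1.0.0/MIMIC-CXR_graphs.json \
#       --output-path ./preprocessing/mimic-cxr-radgraph-itemized.csv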


def get_ids(key):
    """Convert a key in the RadGraph file into partition, patient, and study IDs."""
    lst = key.split("/")
    partition = lst[0]  # dataset partition
    pid = lst[1][1:]  # patient id, drop the leading "p"
    sid = lst[2].split(".")[0][1:]  # study id, drop the leading "s" and the ".txt" suffix
    return partition, pid, sid
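
# For illustration only (the key layout below is an assumption, not taken from the data):
# a key shaped like "train/p12345678/s87654321.txt" would yield
# ("train", "12345678", "87654321").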


def get_sen_from_token_ix(text, ix):
    """Get the sentence (and its index) to which the input token index belongs."""
    sen_lst = nltk.sent_tokenize(text)
    dict_ws = {}  # maps a token index to the index of the sentence containing it
    ix_w = 0  # running token index over the whole report
    ix_s = 0  # running sentence index
    for s in sen_lst:
        words = nltk.word_tokenize(s)
        for w in words:
            dict_ws[ix_w] = ix_s
            ix_w += 1
        ix_s += 1
    return dict_ws[ix], sen_lst[dict_ws[ix]]
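
# Illustrative behavior, assuming NLTK's default sentence/word tokenizers
# (token counts follow nltk.word_tokenize and may not match other tokenizations):
#
#   get_sen_from_token_ix("No pneumothorax. Heart size is normal.", 3)
#   -> (1, "Heart size is normal.")
#
# Tokens 0-2 ("No", "pneumothorax", ".") belong to sentence 0, so token 3
# ("Heart") falls in sentence 1.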


def get_entity_relation(value):
    """Itemize each entity and its outgoing relations into flat lists."""
    source_lst = []
    target_lst = []
    token_lst = []
    token_ix_lst = []
    label_lst = []
    relation_lst = []
    sen_lst = []
    sen_ix_lst = []
    text = value["text"]
    entities = value["entities"]
    for k, v in entities.items():
        six, sen = get_sen_from_token_ix(text, v["start_ix"])
        relations = v["relations"]
        if (len(relations) == 0) or (relations[0] is None):
            # entity has no outgoing edge: emit a single row with empty relation/target
            source_lst.append(k)
            token_ix_lst.append(v["start_ix"])
            token_lst.append(v["tokens"])
            label_lst.append(v["label"])
            relation_lst.append(None)
            target_lst.append(None)
            sen_ix_lst.append(six)
            sen_lst.append(sen)
        else:
            # one row per outgoing relation (relation type, target entity id)
            for r in relations:
                source_lst.append(k)
                token_ix_lst.append(v["start_ix"])
                token_lst.append(v["tokens"])
                label_lst.append(v["label"])
                relation_lst.append(r[0])
                target_lst.append(r[1])
                sen_ix_lst.append(six)
                sen_lst.append(sen)
    # save outputs in a dataframe
    return pd.DataFrame(
        {
            "source": source_lst,
            "token": token_lst,
            "token_ix": token_ix_lst,
            "label": label_lst,
            "relation": relation_lst,
            "target": target_lst,
            "sentence_ix": sen_ix_lst,
            "sentence": sen_lst,
        }
    )
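
# The returned dataframe has one row per (entity, outgoing relation) pair;
# an entity with no outgoing edge contributes a single row whose "relation"
# and "target" columns are None.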


def radgraph_itemize(args):
    """Convert nested RadGraph data to itemized examples."""
    print("Loading RadGraph data...")
    with open(args.data_path) as f:
        data = json.load(f)
    print("RadGraph data is loaded.")
    # create itemized RadGraph data
    df_lst = []
    # pid_lst, sid_lst, and text_lst are accumulated for reference
    # but are not written out by this script
    pid_lst = []
    sid_lst = []
    text_lst = []
    print("Itemizing RadGraph data...")
    for key, value in tqdm(data.items()):
        _, pid, sid = get_ids(key)
        pid_lst.append(pid)
        sid_lst.append(sid)
        text_lst.append(value["text"])
        df = get_entity_relation(value)
        df["subject_id"] = pid
        df["study_id"] = sid
        df_lst.append(df)
    # entity-level dataframe
    df_itemized = pd.concat(df_lst)
    # save the dataframe to a .csv file
    df_itemized.to_csv(args.output_path, index=False)
    print("Outputs have been saved!")


if __name__ == "__main__":
    args = parser.parse_args()
    radgraph_itemize(args)
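
# Quick sanity check on the output (run separately; use whatever path was
# passed as --output-path):
#
#   import pandas as pd
#   df = pd.read_csv("mimic-cxr-radgraph-itemized.csv")
#   print(df[["subject_id", "study_id", "token", "label", "relation", "target"]].head())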