|
""" |
|
Code copied from AGXNet: |
|
https://github.com/batmanlab/AGXNet |
|
""" |
|
|
|
import argparse
import functools
import json

import nltk
import pandas as pd
from tqdm import tqdm
|
|
|
|
|
# Command-line interface: where the RadGraph JSON is read from and where the
# itemized CSV is written to. Both options are optional and fall back to the
# placeholder paths below.
parser = argparse.ArgumentParser(description="Itemize RadGraph Dataset.")
parser.add_argument(
    "--data-path",
    help="RadGraph data path.",
    default="/PATH TO RADGRAPH DATA/RadGraph/physionet.org/files/radgraph/1.0.0/MIMIC-CXR_graphs.json",
)
parser.add_argument(
    "--output-path",
    help="Output path for itemized RadGraph data.",
    default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-itemized.csv",
)
|
|
|
|
|
def get_ids(key):
    """Split a RadGraph record key into (partition, patient id, study id).

    The key is split on "/". The first piece is returned unchanged, the
    second piece is returned without its first character, and the third
    piece is cut at its first "." and then returned without its first
    character.
    """
    pieces = key.split("/")
    # Drop everything from the first "." onwards (e.g. a file extension).
    study_stem, _, _ = pieces[2].partition(".")
    return pieces[0], pieces[1][1:], study_stem[1:]
|
|
|
|
|
@functools.lru_cache(maxsize=8)
def _index_sentences(text):
    """Tokenize *text* once; return its sentences and a token->sentence map.

    The map assigns, to every running word-token index (as produced by
    ``nltk.word_tokenize`` over each sentence in order), the index of the
    sentence that token belongs to. Results are memoized in a small bounded
    cache so that repeated look-ups on the same text do not re-tokenize it.
    """
    sentences = nltk.sent_tokenize(text)
    token_to_sentence = {}
    for sentence_ix, sentence in enumerate(sentences):
        for _ in nltk.word_tokenize(sentence):
            # The next free key is the running token index across sentences.
            token_to_sentence[len(token_to_sentence)] = sentence_ix
    return sentences, token_to_sentence


def get_sen_from_token_ix(text, ix):
    """Return the sentence to which the token at index *ix* belongs.

    Args:
        text: The full text to split into sentences.
        ix: Zero-based token index, counted over the NLTK word tokens of all
            sentences in order.

    Returns:
        A ``(sentence_index, sentence_text)`` tuple.

    Raises:
        KeyError: If *ix* is not a valid token index for *text*.
    """
    # NOTE(review): indices are resolved against NLTK's word tokenization;
    # confirm that the caller's token indices use the same tokenization.
    sentences, token_to_sentence = _index_sentences(text)
    sentence_ix = token_to_sentence[ix]
    return sentence_ix, sentences[sentence_ix]
|
|
|
|
|
def get_entity_relation(value):
    """Flatten one RadGraph record into a table with one row per relation.

    ``value`` must provide ``"text"`` and ``"entities"``; every entity is read
    for its ``"start_ix"``, ``"tokens"``, ``"label"`` and ``"relations"``.
    An entity whose relation list is empty (or whose first relation is
    ``None``) yields a single row with ``None`` as relation and target.
    Otherwise it yields one row per relation, taking the relation from
    element 0 and the target from element 1 of each relation item.

    Returns a ``pandas.DataFrame`` with the columns ``source``, ``token``,
    ``token_ix``, ``label``, ``relation``, ``target``, ``sentence_ix`` and
    ``sentence``.
    """
    # Column name -> values; insertion order fixes the DataFrame column order.
    columns = {
        "source": [],
        "token": [],
        "token_ix": [],
        "label": [],
        "relation": [],
        "target": [],
        "sentence_ix": [],
        "sentence": [],
    }
    report_text = value["text"]

    for entity_id, entity in value["entities"].items():
        sentence_ix, sentence = get_sen_from_token_ix(report_text, entity["start_ix"])
        entity_relations = entity["relations"]

        if len(entity_relations) > 0 and entity_relations[0] is not None:
            links = [(rel[0], rel[1]) for rel in entity_relations]
        else:
            # Entities without relations still get one row of their own.
            links = [(None, None)]

        for relation, target in links:
            columns["source"].append(entity_id)
            columns["token"].append(entity["tokens"])
            columns["token_ix"].append(entity["start_ix"])
            columns["label"].append(entity["label"])
            columns["relation"].append(relation)
            columns["target"].append(target)
            columns["sentence_ix"].append(sentence_ix)
            columns["sentence"].append(sentence)

    return pd.DataFrame(columns)
|
|
|
|
|
def radgraph_itemize(args):
    """Convert nested RadGraph data to itemized examples and save them as CSV.

    Args:
        args: Object exposing ``data_path`` (the RadGraph JSON file to read)
            and ``output_path`` (the CSV file to write).

    The JSON file is loaded as a mapping from record key to record. Each
    record is itemized with ``get_entity_relation`` and tagged with the
    ``subject_id`` and ``study_id`` that ``get_ids`` extracts from its key.
    All per-record tables are then concatenated and written without the index.
    """
    print("Loading RadGraph data...")
    # Binary mode lets the json module detect the file's UTF encoding itself;
    # the context manager closes the handle as soon as parsing is done.
    with open(args.data_path, "rb") as f:
        data = json.load(f)
    print("RadGraph data is loaded.")

    df_lst = []
    print("Itemizing RadGraph data...")
    for key, value in tqdm(data.items()):
        _, pid, sid = get_ids(key)
        df = get_entity_relation(value)
        # Scalar assignment sets the same ID on every row of this record.
        df["subject_id"] = pid
        df["study_id"] = sid
        df_lst.append(df)

    df_itemized = pd.concat(df_lst)

    df_itemized.to_csv(args.output_path, index=False)
    print("Outputs have been saved!")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run the conversion.
    radgraph_itemize(parser.parse_args())
|
|