import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import lightning.pytorch as pl
import config
import sys

sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
from data_proc.data_gen import (
    positive_generator,
    negative_generator,
    get_mentioned_code,
)


class ContrastiveLearningDataset(Dataset):
    def __init__(
        self,
        data: pd.DataFrame,
    ):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        sentence = data_row.sentences
        return sentence


def max_pairwise_sim(sentence1, sentence2, current_df, query_df, sim_df, all_d):
    """Return the maximum ontology similarity score between concept pairs mentioned in sentence1 and sentence2.

    Args:
        sentence1: anchor sentence
        sentence2: negative sentence
        current_df: the dataset the anchor sentence belongs to
        query_df: the union of the training and validation sets
        sim_df: the dataset of pairwise ontology similarity scores
        all_d: the dataset of [concepts, synonyms, list of ancestor concepts]
    """
    anchor_codes = get_mentioned_code(sentence1, current_df)
    other_codes = get_mentioned_code(sentence2, query_df)

    code_pairs = list(zip(anchor_codes, other_codes))
    sim_scores = []
    for code1, code2 in code_pairs:
        if code1 == code2:
            # Identical concepts: use the size of the concept's ancestor set
            # as its self-similarity.
            result = len(all_d.loc[all_d["concept"] == code1, "ancestors"].values[0])
            sim_scores.append(result)
        else:
            # sim_df stores each pair once, so fall back to the swapped
            # column order when the first lookup comes up empty.
            try:
                result = sim_df.loc[
                    (sim_df["Code1"] == code1) & (sim_df["Code2"] == code2), "score"
                ].values[0]
            except IndexError:
                result = sim_df.loc[
                    (sim_df["Code1"] == code2) & (sim_df["Code2"] == code1), "score"
                ].values[0]
            sim_scores.append(result)
    if len(sim_scores) > 0:
        return max(sim_scores)
    else:
        return 0
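
# Illustrative note (toy values, not project data): sim_df is assumed to hold
# each concept pair once, in either column order, e.g.
#
#     toy_sim = pd.DataFrame({"Code1": ["C1"], "Code2": ["C2"], "score": [3]})
#
# Looking up (code1="C2", code2="C1") matches no row in the first order, so
# .values[0] raises IndexError and the swapped lookup in the except branch
# recovers the stored score of 3.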


def collate_simcse(batch, tokenizer):
    """
    Use the first sample in the batch as the anchor, a duplicate of the
    anchor as the positive, and the rest of the batch as negatives.
    """
    anchor = batch[0]
    positive = anchor[:]
    negatives = batch[1:]
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])

    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
        }
    )
    df = pd.concat([df, anchor_row])

    pos_token = tokenizer.encode_plus(
        positive,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pos_row = pd.DataFrame(
        {
            "label": 1,
            "input_ids": pos_token["input_ids"].tolist(),
            "attention_mask": pos_token["attention_mask"].tolist(),
        }
    )
    df = pd.concat([df, pos_row])

    for neg in negatives:
        neg_token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        neg_row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": neg_token["input_ids"].tolist(),
                "attention_mask": neg_token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, neg_row])

    label = torch.tensor(df["label"].tolist())

    # Pad the variable-length sequences to a common length, then transpose to
    # (batch, seq_len).
    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
    }
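
# Shape note: for a DataLoader batch of B sentences, collate_simcse returns
# B + 1 rows (1 anchor, 1 positive, B - 1 in-batch negatives), so "label" is a
# (B + 1,) tensor over {0, 1, 2} and "input_ids" / "attention_mask" are
# (B + 1, max_len) tensors padded to the longest sequence in the group.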


def create_dataloader_simcse(
    dataset,
    tokenizer,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size_simcse,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_simcse(
            batch,
            tokenizer,
        ),
    )


class ContrastiveLearningDataModule_simcse(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        tokenizer,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_simcse(
            self.train_dataset,
            self.tokenizer,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_simcse(
            self.val_dataset,
            self.tokenizer,
            shuffle=False,
        )


def collate_simcse_w(
    batch,
    current_df,
    query_df,
    tokenizer,
    sim_df,
    all_d,
):
    """
    Anchor: 0
    Positive: 1
    Negative: 2
    """
    anchor = batch[0]
    positive = anchor[:]
    negatives = batch[1:]
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])

    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, anchor_row])

    pos_token = tokenizer.encode_plus(
        positive,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pos_row = pd.DataFrame(
        {
            "label": 1,
            "input_ids": pos_token["input_ids"].tolist(),
            "attention_mask": pos_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, pos_row])

    for neg in negatives:
        neg_token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8
        score = score + offset
        neg_row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": neg_token["input_ids"].tolist(),
                "attention_mask": neg_token["attention_mask"].tolist(),
                "score": score,
            }
        )
        df = pd.concat([df, neg_row])

    label = torch.tensor(df["label"].tolist())

    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    score = torch.tensor(df["score"].tolist())

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "score": score,
    }
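
# Shape note: collate_simcse_w returns the same fields as collate_simcse plus
# a (B + 1,) "score" tensor. Anchors and positives carry a fixed score of 1,
# and each negative carries its ontology similarity to the anchor plus the
# constant offset of 8 (presumably consumed as a per-row weight downstream).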


def create_dataloader_simcse_w(
    dataset,
    current_df,
    query_df,
    tokenizer,
    sim_df,
    all_d,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size_simcse,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_simcse_w(
            batch,
            current_df,
            query_df,
            tokenizer,
            sim_df,
            all_d,
        ),
    )


class ContrastiveLearningDataModule_simcse_w(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        sim_df,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.sim_df = sim_df
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_simcse_w(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.sim_df,
            self.all_d,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_simcse_w(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.sim_df,
            self.all_d,
            shuffle=False,
        )


def collate_samp(
    sentence,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
):
    """
    Use the first sentence in the batch as the anchor and draw positives and
    negatives for it with positive_generator / negative_generator.
    Anchor: 0
    Positive: 1
    Negative: 2
    """
    anchor = sentence[0]
    positives = positive_generator(
        anchor, current_df, query_df, dictionary, num_pos=config.num_pos
    )
    negatives = negative_generator(
        anchor,
        current_df,
        query_df,
        dictionary,
        sim_df,
        num_neg=config.num_neg,
    )
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])
    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
        }
    )
    df = pd.concat([df, anchor_row])

    for pos in positives:
        token = tokenizer.encode_plus(
            pos,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 1,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, row])

    for neg in negatives:
        token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, row])

    label = torch.tensor(df["label"].tolist())

    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
    }
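
# Shape note: collate_samp ignores everything in the batch except the first
# sentence, so each batch holds 1 + num_pos + num_neg rows (assuming
# positive_generator / negative_generator return exactly config.num_pos and
# config.num_neg sentences), independent of config.batch_size.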


def create_dataloader_samp(
    dataset,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_samp(
            batch,
            current_df,
            query_df,
            tokenizer,
            dictionary,
            sim_df,
        ),
    )


class ContrastiveLearningDataModule_samp(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.sim_df = sim_df

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_samp(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_samp(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            shuffle=False,
        )


def collate_samp_w(
    sentence,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    all_d,
):
    """
    Anchor: 0
    Positive: 1
    Negative: 2
    """
    anchor = sentence[0]
    positives = positive_generator(
        anchor, current_df, query_df, dictionary, num_pos=config.num_pos
    )
    negatives = negative_generator(
        anchor,
        current_df,
        query_df,
        dictionary,
        sim_df,
        num_neg=config.num_neg,
    )
    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, anchor_row])

    for pos in positives:
        token = tokenizer.encode_plus(
            pos,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 1,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
                "score": 1,
            }
        )
        df = pd.concat([df, row])

    for neg in negatives:
        token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8
        score = score + offset
        row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
                "score": score,
            }
        )
        df = pd.concat([df, row])

    label = torch.tensor(df["label"].tolist())

    input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"]))
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)

    attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"]))
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)

    score = torch.tensor(df["score"].tolist())

    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "score": score,
    }


def create_dataloader_samp_w(
    dataset,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    all_d,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_samp_w(
            batch,
            current_df,
            query_df,
            tokenizer,
            dictionary,
            sim_df,
            all_d,
        ),
    )


class ContrastiveLearningDataModule_samp_w(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.sim_df = sim_df
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_samp_w(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            self.all_d,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_samp_w(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            self.all_d,
            shuffle=False,
        )


if __name__ == "__main__":
    # Quick smoke test: build the four data modules and pull one training
    # batch from each.
    from transformers import AutoTokenizer
    from ast import literal_eval
    from sklearn.model_selection import train_test_split

    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    query_df = query_df.drop(columns=["one_hot"])
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    sim_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
    )

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
    )
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    d1 = ContrastiveLearningDataModule_simcse(train_df, val_df, tokenizer)
    d1.setup()
    train_d1 = d1.train_dataloader()
    b1 = next(iter(train_d1))

    d2 = ContrastiveLearningDataModule_simcse_w(
        train_df, val_df, query_df, tokenizer, sim_df, all_d
    )
    d2.setup()
    train_d2 = d2.train_dataloader()
    b2 = next(iter(train_d2))

    d3 = ContrastiveLearningDataModule_samp(
        train_df, val_df, query_df, tokenizer, dictionary, sim_df
    )
    d3.setup()
    train_d3 = d3.train_dataloader()
    b3 = next(iter(train_d3))

    d4 = ContrastiveLearningDataModule_samp_w(
        train_df, val_df, query_df, tokenizer, dictionary, sim_df, all_d
    )
    d4.setup()
    train_d4 = d4.train_dataloader()
    b4 = next(iter(train_d4))