import sys

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import lightning.pytorch as pl

import config

sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
from data_proc.data_gen import (
    positive_generator,
    negative_generator,
    get_mentioned_code,
)


##### General
class ContrastiveLearningDataset(Dataset):
    def __init__(
        self,
        data: pd.DataFrame,
    ):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        sentence = data_row.sentences
        return sentence


def max_pairwise_sim(sentence1, sentence2, current_df, query_df, sim_df, all_d):
    """Returns the maximum ontology similarity score between concept pairs
    mentioned in sentence1 and sentence2.

    Args:
        sentence1: anchor sentence
        sentence2: negative sentence
        current_df: the dataset where the anchor sentence stays
        query_df: the union of the training and validation sets
        sim_df: the dataset of pairwise ontology similarity scores
        all_d: the dataset of [concepts, synonyms, list of ancestor concepts]
    """
    # retrieve concepts from the two sentences
    anchor_codes = get_mentioned_code(sentence1, current_df)
    other_codes = get_mentioned_code(sentence2, query_df)

    # create SNOMED CT code pairs and calculate the score using sim_df
    code_pairs = list(zip(anchor_codes, other_codes))
    sim_scores = []
    for code1, code2 in code_pairs:
        if code1 == code2:
            # identical concepts: use the size of the concept's ancestor set
            result = len(all_d.loc[all_d["concept"] == code1, "ancestors"].values[0])
            sim_scores.append(result)
        else:
            try:
                result = sim_df.loc[
                    (sim_df["Code1"] == code1) & (sim_df["Code2"] == code2), "score"
                ].values[0]
            except IndexError:
                # the pair is stored in the opposite order
                result = sim_df.loc[
                    (sim_df["Code1"] == code2) & (sim_df["Code2"] == code1), "score"
                ].values[0]
            sim_scores.append(result)
    if len(sim_scores) > 0:
        return max(sim_scores)
    else:
        return 0

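# Hedged illustration (not used by the pipeline): max_pairwise_sim assumes sim_df has
# columns ["Code1", "Code2", "score"] with each unordered code pair stored once, and
# all_d has columns ["concept", "ancestors"]. The toy frames and codes below are
# made-up values that only demonstrate the symmetric fallback and the identical-code case.
def _demo_pairwise_sim_lookup():
    toy_sim_df = pd.DataFrame({"Code1": ["123"], "Code2": ["456"], "score": [5]})
    toy_all_d = pd.DataFrame({"concept": ["123"], "ancestors": [["a", "b", "c"]]})

    def lookup(code1, code2):
        if code1 == code2:
            # identical concepts: score is the size of the ancestor set
            return len(toy_all_d.loc[toy_all_d["concept"] == code1, "ancestors"].values[0])
        try:
            return toy_sim_df.loc[
                (toy_sim_df["Code1"] == code1) & (toy_sim_df["Code2"] == code2), "score"
            ].values[0]
        except IndexError:
            # pair stored in the opposite order
            return toy_sim_df.loc[
                (toy_sim_df["Code1"] == code2) & (toy_sim_df["Code2"] == code1), "score"
            ].values[0]

    assert lookup("456", "123") == 5  # reversed pair falls back to the swapped lookup
    assert lookup("123", "123") == 3  # identical codes use the ancestor count
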
""" anchor = batch[0] # use the first sample in the batch as anchor positive = anchor[:] # create a duplicate of anchor as positive negatives = batch[1:] # everything else as negatives df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"]) anchor_token = tokenizer.encode_plus( anchor, return_token_type_ids=False, return_attention_mask=True, return_tensors="pt", ) anchor_row = pd.DataFrame( { "label": 0, "input_ids": anchor_token["input_ids"].tolist(), "attention_mask": anchor_token["attention_mask"].tolist(), } ) df = pd.concat([df, anchor_row]) pos_token = tokenizer.encode_plus( positive, return_token_type_ids=False, return_attention_mask=True, return_tensors="pt", ) pos_row = pd.DataFrame( { "label": 1, "input_ids": pos_token["input_ids"].tolist(), "attention_mask": pos_token["attention_mask"].tolist(), } ) df = pd.concat([df, pos_row]) for neg in negatives: neg_token = tokenizer.encode_plus( neg, return_token_type_ids=False, return_attention_mask=True, return_tensors="pt", ) neg_row = pd.DataFrame( { "label": 2, "input_ids": neg_token["input_ids"].tolist(), "attention_mask": neg_token["attention_mask"].tolist(), } ) df = pd.concat([df, neg_row]) label = torch.tensor(df["label"].tolist()) input_ids_tsr = list(map(lambda x: torch.tensor(x), df["input_ids"])) padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id) padded_input_ids = torch.transpose(padded_input_ids, 0, 1) attention_mask_tsr = list(map(lambda x: torch.tensor(x), df["attention_mask"])) padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0) padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1) return { "label": label, "input_ids": padded_input_ids, "attention_mask": padded_attention_mask, } def create_dataloader_simcse( dataset, tokenizer, shuffle, ): return DataLoader( dataset, batch_size=config.batch_size_simcse, shuffle=shuffle, num_workers=config.num_workers, collate_fn=lambda batch: collate_simcse( batch, tokenizer, ), ) class ContrastiveLearningDataModule_simcse(pl.LightningDataModule): def __init__( self, train_df, val_df, tokenizer, ): super().__init__() self.train_df = train_df self.val_df = val_df self.tokenizer = tokenizer def setup(self, stage=None): self.train_dataset = ContrastiveLearningDataset(self.train_df) self.val_dataset = ContrastiveLearningDataset(self.val_df) def train_dataloader(self): return create_dataloader_simcse( self.train_dataset, self.tokenizer, shuffle=True, ) def val_dataloader(self): return create_dataloader_simcse( self.val_dataset, self.tokenizer, shuffle=False, ) ##### SimCSE_w def collate_simcse_w( batch, current_df, query_df, tokenizer, sim_df, all_d, ): """ Anchor: 0 Positive: 1 Negative: 2 """ anchor = batch[0] positive = anchor[:] negatives = batch[1:] df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"]) anchor_token = tokenizer.encode_plus( anchor, return_token_type_ids=False, return_attention_mask=True, return_tensors="pt", ) anchor_row = pd.DataFrame( { "label": 0, "input_ids": anchor_token["input_ids"].tolist(), "attention_mask": anchor_token["attention_mask"].tolist(), "score": 1, } ) df = pd.concat([df, anchor_row]) pos_token = tokenizer.encode_plus( positive, return_token_type_ids=False, return_attention_mask=True, return_tensors="pt", ) pos_row = pd.DataFrame( { "label": 1, "input_ids": pos_token["input_ids"].tolist(), "attention_mask": pos_token["attention_mask"].tolist(), "score": 1, } ) df = pd.concat([df, pos_row]) for neg in negatives: neg_token = tokenizer.encode_plus( 
##### SimCSE_w
def collate_simcse_w(
    batch,
    current_df,
    query_df,
    tokenizer,
    sim_df,
    all_d,
):
    """
    Anchor: 0
    Positive: 1
    Negative: 2
    """
    anchor = batch[0]
    positive = anchor[:]
    negatives = batch[1:]

    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, anchor_row])

    pos_token = tokenizer.encode_plus(
        positive,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pos_row = pd.DataFrame(
        {
            "label": 1,
            "input_ids": pos_token["input_ids"].tolist(),
            "attention_mask": pos_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, pos_row])

    for neg in negatives:
        neg_token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8  # all negative scores start at 8 to distinguish them from the positives
        score = score + offset
        neg_row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": neg_token["input_ids"].tolist(),
                "attention_mask": neg_token["attention_mask"].tolist(),
                "score": score,
            }
        )
        df = pd.concat([df, neg_row])

    label = torch.tensor(df["label"].tolist())
    input_ids_tsr = [torch.tensor(x) for x in df["input_ids"]]
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
    attention_mask_tsr = [torch.tensor(x) for x in df["attention_mask"]]
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
    score = torch.tensor(df["score"].tolist())
    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "score": score,
    }


def create_dataloader_simcse_w(
    dataset,
    current_df,
    query_df,
    tokenizer,
    sim_df,
    all_d,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size_simcse,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_simcse_w(
            batch,
            current_df,
            query_df,
            tokenizer,
            sim_df,
            all_d,
        ),
    )


class ContrastiveLearningDataModule_simcse_w(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        sim_df,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.sim_df = sim_df
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_simcse_w(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.sim_df,
            self.all_d,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_simcse_w(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.sim_df,
            self.all_d,
            shuffle=False,
        )

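# Hedged note on the "score" field produced by collate_simcse_w (and collate_samp_w
# below): the anchor and the positive always carry score 1, while each negative carries
# 8 + its maximum ontology similarity to the anchor, so scores >= 8 identify negatives
# downstream. The similarity values below are made up and only illustrate the layout.
def _demo_weighted_score_layout():
    toy_similarities = [0, 2, 5]  # hypothetical max_pairwise_sim outputs for 3 negatives
    scores = [1, 1] + [8 + s for s in toy_similarities]
    assert scores == [1, 1, 8, 10, 13]  # [anchor, positive, negatives...]
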
##### Samp
def collate_samp(
    sentence,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
):
    """
    Anchor: 0
    Positive: 1
    Negative: 2
    """
    anchor = sentence[0]
    positives = positive_generator(
        anchor, current_df, query_df, dictionary, num_pos=config.num_pos
    )
    negatives = negative_generator(
        anchor,
        current_df,
        query_df,
        dictionary,
        sim_df,
        num_neg=config.num_neg,
    )

    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask"])
    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
        }
    )
    df = pd.concat([df, anchor_row])

    for pos in positives:
        token = tokenizer.encode_plus(
            pos,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 1,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, row])

    for neg in negatives:
        token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
            }
        )
        df = pd.concat([df, row])

    label = torch.tensor(df["label"].tolist())
    input_ids_tsr = [torch.tensor(x) for x in df["input_ids"]]
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
    attention_mask_tsr = [torch.tensor(x) for x in df["attention_mask"]]
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
    }


def create_dataloader_samp(
    dataset,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_samp(
            batch,
            current_df,
            query_df,
            tokenizer,
            dictionary,
            sim_df,
        ),
    )


class ContrastiveLearningDataModule_samp(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.sim_df = sim_df

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_samp(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_samp(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            shuffle=False,
        )

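# Hedged note on the batches built by collate_samp / collate_samp_w: each batch holds
# one anchor (label 0), the sampled positives (label 1), then the sampled negatives
# (label 2), in that order. positive_generator / negative_generator are assumed to
# return plain lists of sentence strings, up to num_pos / num_neg items each; the
# counts below are illustrative only.
def _demo_samp_label_layout(num_pos=2, num_neg=3):
    labels = [0] + [1] * num_pos + [2] * num_neg
    assert labels == [0, 1, 1, 2, 2, 2]
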
##### Samp_w
def collate_samp_w(
    sentence,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    all_d,
):
    """
    Anchor: 0
    Positive: 1
    Negative: 2
    """
    anchor = sentence[0]
    positives = positive_generator(
        anchor, current_df, query_df, dictionary, num_pos=config.num_pos
    )
    negatives = negative_generator(
        anchor,
        current_df,
        query_df,
        dictionary,
        sim_df,
        num_neg=config.num_neg,
    )

    df = pd.DataFrame(columns=["label", "input_ids", "attention_mask", "score"])
    anchor_token = tokenizer.encode_plus(
        anchor,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    anchor_row = pd.DataFrame(
        {
            "label": 0,
            "input_ids": anchor_token["input_ids"].tolist(),
            "attention_mask": anchor_token["attention_mask"].tolist(),
            "score": 1,
        }
    )
    df = pd.concat([df, anchor_row])

    for pos in positives:
        token = tokenizer.encode_plus(
            pos,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        row = pd.DataFrame(
            {
                "label": 1,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
                "score": 1,
            }
        )
        df = pd.concat([df, row])

    for neg in negatives:
        token = tokenizer.encode_plus(
            neg,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        score = max_pairwise_sim(anchor, neg, current_df, query_df, sim_df, all_d)
        offset = 8  # all negative scores start with 8 to distinguish from the positives
        score = score + offset
        row = pd.DataFrame(
            {
                "label": 2,
                "input_ids": token["input_ids"].tolist(),
                "attention_mask": token["attention_mask"].tolist(),
                "score": score,
            }
        )
        df = pd.concat([df, row])

    label = torch.tensor(df["label"].tolist())
    input_ids_tsr = [torch.tensor(x) for x in df["input_ids"]]
    padded_input_ids = pad_sequence(input_ids_tsr, padding_value=tokenizer.pad_token_id)
    padded_input_ids = torch.transpose(padded_input_ids, 0, 1)
    attention_mask_tsr = [torch.tensor(x) for x in df["attention_mask"]]
    padded_attention_mask = pad_sequence(attention_mask_tsr, padding_value=0)
    padded_attention_mask = torch.transpose(padded_attention_mask, 0, 1)
    score = torch.tensor(df["score"].tolist())
    return {
        "label": label,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "score": score,
    }


def create_dataloader_samp_w(
    dataset,
    current_df,
    query_df,
    tokenizer,
    dictionary,
    sim_df,
    all_d,
    shuffle,
):
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        collate_fn=lambda batch: collate_samp_w(
            batch,
            current_df,
            query_df,
            tokenizer,
            dictionary,
            sim_df,
            all_d,
        ),
    )


class ContrastiveLearningDataModule_samp_w(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.query_df = query_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.sim_df = sim_df
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = ContrastiveLearningDataset(self.train_df)
        self.val_dataset = ContrastiveLearningDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader_samp_w(
            self.train_dataset,
            self.train_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            self.all_d,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader_samp_w(
            self.val_dataset,
            self.val_df,
            self.query_df,
            self.tokenizer,
            self.dictionary,
            self.sim_df,
            self.all_d,
            shuffle=False,
        )


#### Test
from transformers import AutoTokenizer
from ast import literal_eval
from sklearn.model_selection import train_test_split

query_df = pd.read_csv(
    "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
)
query_df["concepts"] = query_df["concepts"].apply(literal_eval)
query_df["codes"] = query_df["codes"].apply(literal_eval)
query_df["codes"] = query_df["codes"].apply(
    lambda x: [val for val in x if val is not None]
)  # remove None in lists
query_df = query_df.drop(columns=["one_hot"])
train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
sim_df = pd.read_csv(
    "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
)
all_d = pd.read_csv(
    "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
)
all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

d1 = ContrastiveLearningDataModule_simcse(train_df, val_df, tokenizer)
d1.setup()
train_d1 = d1.train_dataloader()
for batch in train_d1:
    b1 = batch
    break

d2 = ContrastiveLearningDataModule_simcse_w(
    train_df, val_df, query_df, tokenizer, sim_df, all_d
)
d2.setup()
train_d2 = d2.train_dataloader()
for batch in train_d2:
    b2 = batch
    break

d3 = ContrastiveLearningDataModule_samp(
    train_df, val_df, query_df, tokenizer, dictionary, sim_df
)
d3.setup()
train_d3 = d3.train_dataloader()
for batch in train_d3:
    b3 = batch
    break

d4 = ContrastiveLearningDataModule_samp_w(
    train_df, val_df, query_df, tokenizer, dictionary, sim_df, all_d
)
d4.setup()
train_d4 = d4.train_dataloader()
for batch in train_d4:
    b4 = batch
    break