import torch
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import lightning.pytorch as pl
import config
import pandas as pd
from ast import literal_eval
from sklearn.model_selection import train_test_split
import sys

sys.path.append("/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag")
from data_proc.data_gen import (
    positive_generator,
    positive_generator_alter,
    negative_generator,
    negative_generator_alter,
    negative_generator_random,
    negative_generator_v2,
    get_mentioned_code,
)
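
# Data pipeline for CL (contrastive learning) over clinical sentences: each
# DataLoader batch is one anchor sentence expanded into generated positive and
# negative sentences, tokenized with masked-LM ids/labels attached.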


def tokenize(text, tokenizer, tag):
    """Tokenize one sentence and build masked-LM ids/labels alongside it.

    `tag` marks the sentence's role in the contrastive group: "A" (anchor),
    "P" (positive), or "N" (negative).
    """
    inputs = tokenizer(
        text,
        return_token_type_ids=False,
        return_tensors="pt",
    )

    # Drop the batch dimension; padding happens later in the collate_fn.
    inputs["input_ids"] = inputs["input_ids"][0]
    inputs["attention_mask"] = inputs["attention_mask"][0]
    inputs["mlm_ids"] = inputs["input_ids"].clone()
    inputs["mlm_labels"] = inputs["input_ids"].clone()

    # Never mask BERT special tokens: [CLS] = 101, [SEP] = 102, [PAD] = 0.
    tokens_to_ignore = torch.tensor([101, 102, 0])
    valid_tokens = inputs["input_ids"][
        ~torch.isin(inputs["input_ids"], tokens_to_ignore)
    ]
    num_of_token_to_mask = int(len(valid_tokens) * config.mask_pct)
    token_to_mask = valid_tokens[
        torch.randperm(valid_tokens.size(0))[:num_of_token_to_mask]
    ]
    # Masking is by token id, so every occurrence of a sampled id gets masked.
    masked = torch.isin(inputs["mlm_ids"], token_to_mask)
    inputs["mlm_ids"][masked] = 103  # [MASK]
    inputs["mlm_labels"][~masked] = -100  # only masked positions count in the loss

    # Role tag: anchor = 0, positive = 1, negative = 2.
    if tag == "A":
        inputs["tag"] = 0
    elif tag == "P":
        inputs["tag"] = 1
    elif tag == "N":
        inputs["tag"] = 2
    return inputs
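
# Example (hypothetical sentence; tokenizer as loaded in __main__ below):
#   enc = tokenize("mild aortic stenosis", tokenizer, "A")
#   enc["input_ids"], enc["mlm_ids"], and enc["mlm_labels"] are equal-length
#   1-D tensors, and enc["tag"] == 0.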


class CLDataset(Dataset):
    """Dataset of raw sentences; tokenization and expansion happen in the collate_fn."""

    def __init__(self, data: pd.DataFrame):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        return data_row.sentences


def collate_func(batch, tokenizer, current_df, query_df, dictionary, all_d):
    # Only batch[0] is used as the anchor, so this collate_fn effectively
    # expects batch_size == 1; the anchor is expanded into a group of one
    # anchor plus generated positives and negatives.
    anchor = batch[0]
    positives = positive_generator_alter(
        anchor,
        current_df,
        dictionary,
        num_pos=config.num_pos,
    )
    negatives = negative_generator_v2(
        anchor,
        current_df,
        query_df,
        all_d,
        num_neg=config.num_neg,
    )

    inputs = [tokenize(anchor, tokenizer, "A")]
    inputs += [tokenize(pos, tokenizer, "P") for pos in positives]
    inputs += [tokenize(neg, tokenizer, "N") for neg in negatives]

    tags = torch.tensor([d["tag"] for d in inputs])

    # Pad every field to the longest sequence in the group.
    padded_input_ids = pad_sequence(
        [d["input_ids"] for d in inputs], batch_first=True, padding_value=0
    )
    padded_attention_mask = pad_sequence(
        [d["attention_mask"] for d in inputs], batch_first=True, padding_value=0
    )
    padded_mlm_ids = pad_sequence(
        [d["mlm_ids"] for d in inputs], batch_first=True, padding_value=0
    )
    # Pad labels with -100 so padded positions are ignored by the MLM loss.
    padded_mlm_labels = pad_sequence(
        [d["mlm_labels"] for d in inputs], batch_first=True, padding_value=-100
    )

    return {
        "tags": tags,
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "mlm_ids": padded_mlm_ids,
        "mlm_labels": padded_mlm_labels,
    }
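
# Every tensor in the returned dict has shape
# (1 + len(positives) + len(negatives), longest_seq_in_group); row 0 is the
# anchor (tag 0), followed by positives (tag 1) and negatives (tag 2).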


def create_dataloader(
    dataset, tokenizer, shuffle, current_df, query_df, dictionary, all_d
):
    # NOTE: a lambda collate_fn cannot be pickled, so num_workers > 0 relies on
    # fork-based worker processes (the default on Linux).
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=1,
        collate_fn=lambda batch: collate_func(
            batch, tokenizer, current_df, query_df, dictionary, all_d
        ),
    )


class CLDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_df,
        val_df,
        tokenizer,
        query_df,
        dictionary,
        all_d,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer
        self.query_df = query_df
        self.dictionary = dictionary
        self.all_d = all_d

    def setup(self, stage=None):
        self.train_dataset = CLDataset(self.train_df)
        self.val_dataset = CLDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader(
            self.train_dataset,
            self.tokenizer,
            shuffle=True,
            current_df=self.train_df,
            query_df=self.query_df,
            dictionary=self.dictionary,
            all_d=self.all_d,
        )

    def val_dataloader(self):
        return create_dataloader(
            self.val_dataset,
            self.tokenizer,
            shuffle=False,
            current_df=self.val_df,
            query_df=self.query_df,
            dictionary=self.dictionary,
            all_d=self.all_d,
        )


if __name__ == "__main__":
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    # "concepts"/"codes" are stored as stringified lists; parse them back.
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    all_d["finding_sites"] = all_d["finding_sites"].apply(literal_eval)
    all_d["morphology"] = all_d["morphology"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    # Smoke test: build the data module and pull a single batch.
    d = CLDataModule(train_df, val_df, tokenizer, query_df, dictionary, all_d)
    d.setup()
    train = d.train_dataloader()
    for batch in train:
        b = batch
        break
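    # b is one padded anchor/positive/negative group, a dict with keys:
    # tags, input_ids, attention_mask, mlm_ids, mlm_labels.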