|
"""Train four BERT contrastive-learning variants (simcse, simcse_w, samp,
samp_w) back-to-back on the processed MIMIC concept data."""

import math
from ast import literal_eval

import pandas as pd
import lightning.pytorch as pl
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

# Local modules. config is expected to define: split_ratio, batch_size,
# max_epochs, min_epochs, learning_rate, unfreeze_ratio, accelerator,
# devices, precision, and log_every_n_steps (all referenced below).
import config
from model import (
    BERTContrastiveLearning_simcse,
    BERTContrastiveLearning_simcse_w,
    BERTContrastiveLearning_samp,
    BERTContrastiveLearning_samp_w,
)
from dataset import (
    ContrastiveLearningDataModule_simcse,
    ContrastiveLearningDataModule_simcse_w,
    ContrastiveLearningDataModule_samp,
    ContrastiveLearningDataModule_samp_w,
)
|
|
|
if __name__ == "__main__":
    # Seed Python, NumPy, and torch (plus DataLoader workers) so every
    # variant trains on the same split and initialization.
    seed_everything(0, workers=True)

    logger = TensorBoardLogger("logs", name="MIMIC-tr")
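    # NOTE: this one logger instance is shared by all four Trainers below, so
    # their metrics are interleaved in a single TensorBoard run; per-variant
    # loggers (e.g. name="MIMIC-tr-simcse") would keep the curves separate.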
|
    # The list-valued "concepts" and "codes" columns are stored as strings
    # in the CSV, so parse them back with literal_eval.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    # Drop missing entries inside each row's code list.
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    query_df = query_df.drop(columns=["one_hot"])
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
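    # NOTE: train_test_split draws from NumPy's global RNG, which
    # seed_everything(0) seeded above, so the split is reproducible as long
    # as no other NumPy RNG call runs first; an explicit random_state= would
    # pin it unconditionally.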
|
    # The tokenizer must match the encoder checkpoint the models fine-tune.
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    # Precomputed pairwise concept-similarity scores.
    sim_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
    )

    # Concept table with (stringified) synonym and ancestor lists.
    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
    )
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
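    # Four variants are trained back-to-back with identical Trainer settings.
    # Judging by the constructor arguments, the samp variants additionally
    # consume the synonym dictionary, and the _w variants the similarity
    # scores (sim_df) and concept table (all_d).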
|
    # --- Experiment 1: simcse ---
    data_module1 = ContrastiveLearningDataModule_simcse(
        train_df,
        val_df,
        tokenizer,
    )
    data_module1.setup()

    print("Number of training samples:", len(data_module1.train_dataset))
    print("Number of validation samples:", len(data_module1.val_dataset))

    model1 = BERTContrastiveLearning_simcse(
        # Batches per epoch; round up so a partial final batch is counted.
        n_batches=math.ceil(len(data_module1.train_dataset) / config.batch_size),
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        unfreeze=config.unfreeze_ratio,
    )

    # Save a checkpoint every config.log_every_n_steps training steps and
    # keep all of them (no metric is monitored), plus last.ckpt.
    checkpoint1 = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/simcse/v1",
        filename="{epoch}-{step}",
        save_last=True,
        every_n_train_steps=config.log_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )

    trainer1 = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            # Stop early once the logged "validation_loss" plateaus.
            EarlyStopping(
                monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
            ),
            checkpoint1,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )

    trainer1.fit(model1, data_module1)
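    # NOTE: with strategy="ddp", Lightning's launcher re-executes this script
    # in every worker process, so all four fits run in sequence on each rank.
    # If one variant ever needs to run in isolation, selecting it with a CLI
    # argument (one fit per launch) is the more robust pattern.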
|
    # --- Experiment 2: simcse_w (adds sim_df and all_d) ---
    data_module2 = ContrastiveLearningDataModule_simcse_w(
        train_df,
        val_df,
        query_df,
        tokenizer,
        sim_df,
        all_d,
    )
    data_module2.setup()

    print("Number of training samples:", len(data_module2.train_dataset))
    print("Number of validation samples:", len(data_module2.val_dataset))

    model2 = BERTContrastiveLearning_simcse_w(
        n_batches=math.ceil(len(data_module2.train_dataset) / config.batch_size),
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        unfreeze=config.unfreeze_ratio,
    )

    checkpoint2 = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/simcse_w/v1",
        filename="{epoch}-{step}",
        save_last=True,
        every_n_train_steps=config.log_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )

    trainer2 = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            EarlyStopping(
                monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
            ),
            checkpoint2,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )

    trainer2.fit(model2, data_module2)
|
    # --- Experiment 3: samp (adds the synonym dictionary) ---
    data_module3 = ContrastiveLearningDataModule_samp(
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
    )
    data_module3.setup()

    print("Number of training samples:", len(data_module3.train_dataset))
    print("Number of validation samples:", len(data_module3.val_dataset))

    model3 = BERTContrastiveLearning_samp(
        n_batches=math.ceil(len(data_module3.train_dataset) / config.batch_size),
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        unfreeze=config.unfreeze_ratio,
    )

    checkpoint3 = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/samp/v1",
        filename="{epoch}-{step}",
        save_last=True,
        every_n_train_steps=config.log_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )

    trainer3 = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            EarlyStopping(
                monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
            ),
            checkpoint3,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )

    trainer3.fit(model3, data_module3)
|
    # --- Experiment 4: samp_w (synonym dictionary plus sim_df and all_d) ---
    data_module4 = ContrastiveLearningDataModule_samp_w(
        train_df,
        val_df,
        query_df,
        tokenizer,
        dictionary,
        sim_df,
        all_d,
    )
    data_module4.setup()

    print("Number of training samples:", len(data_module4.train_dataset))
    print("Number of validation samples:", len(data_module4.val_dataset))

    model4 = BERTContrastiveLearning_samp_w(
        n_batches=math.ceil(len(data_module4.train_dataset) / config.batch_size),
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        unfreeze=config.unfreeze_ratio,
    )

    checkpoint4 = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/samp_w/v1",
        filename="{epoch}-{step}",
        save_last=True,
        every_n_train_steps=config.log_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )

    trainer4 = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            EarlyStopping(
                monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
            ),
            checkpoint4,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )

    trainer4.fit(model4, data_module4)
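    # Each variant's checkpoints land under .../train/ckpt/<variant>/v1, and
    # save_last=True also writes last.ckpt there. A run can later be restored
    # with the standard LightningModule API, e.g.
    # BERTContrastiveLearning_simcse.load_from_checkpoint("<dirpath>/last.ckpt"),
    # assuming each model saves its constructor arguments via
    # save_hyperparameters().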
|
|