from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import lightning.pytorch as pl
import math

import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from ast import literal_eval

# imports from our own modules
import config
from model import (
    BERTContrastiveLearning_simcse,
    BERTContrastiveLearning_simcse_w,
    BERTContrastiveLearning_samp,
    BERTContrastiveLearning_samp_w,
)
from dataset import (
    ContrastiveLearningDataModule_simcse,
    ContrastiveLearningDataModule_simcse_w,
    ContrastiveLearningDataModule_samp,
    ContrastiveLearningDataModule_samp_w,
)

if __name__ == "__main__":
    seed_everything(0, workers=True)

    # Initialize the TensorBoard logger, shared by the four runs below
    logger = TensorBoardLogger("logs", name="MIMIC-tr")

    # Load the preprocessed MIMIC training queries
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/mimic_data/processed_train/processed.csv"
    )
    # query_df = query_df.head(1000)
    # List-valued columns are stored as strings in the CSV; parse them back
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )  # remove None in lists
    query_df = query_df.drop(columns=["one_hot"])
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    # Pairwise similarity scores and the concept table (synonyms, ancestors)
    sim_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairwise_scores.csv"
    )
    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/all_d_full.csv"
    )
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    # SimCSE
    data_module1 = ContrastiveLearningDataModule_simcse(
        train_df,
        val_df,
        tokenizer,
    )
    data_module1.setup()
    print("Number of training data:", len(data_module1.train_dataset))
    print("Number of validation data:", len(data_module1.val_dataset))
    model1 = BERTContrastiveLearning_simcse(
        # ceil keeps the batch count an integer when the dataset size
        # is not a multiple of the batch size
        n_batches=math.ceil(len(data_module1.train_dataset) / config.batch_size),
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        unfreeze=config.unfreeze_ratio,
    )
    checkpoint1 = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/simcse/v1",
        filename="{epoch}-{step}",
        # save_weights_only=True,
        save_last=True,
        every_n_train_steps=config.log_every_n_steps,
        monitor=None,  # no metric to rank by; keep every periodic checkpoint
        save_top_k=-1,
    )
    trainer1 = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            EarlyStopping(
                monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
            ),
            checkpoint1,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )
    trainer1.fit(model1, data_module1)
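
    # The three runs below repeat the same checkpoint/trainer recipe, varying
    # only the data module, model class, and checkpoint directory. A minimal
    # sketch of a factory that could remove the duplication; the helper name
    # `build_trainer` is illustrative, not part of this repo:
    #
    # def build_trainer(checkpoint_cb):
    #     return pl.Trainer(
    #         accelerator=config.accelerator,
    #         devices=config.devices,
    #         strategy="ddp",
    #         logger=logger,
    #         max_epochs=config.max_epochs,
    #         min_epochs=config.min_epochs,
    #         precision=config.precision,
    #         callbacks=[
    #             EarlyStopping(
    #                 monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
    #             ),
    #             checkpoint_cb,
    #         ],
    #         profiler="simple",
    #         log_every_n_steps=config.log_every_n_steps,
    #     )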
dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/simcse_w/v1", filename="{epoch}-{step}", # save_weights_only=True, save_last=True, every_n_train_steps=config.log_every_n_steps, monitor=None, save_top_k=-1, ) trainer2 = pl.Trainer( accelerator=config.accelerator, devices=config.devices, strategy="ddp", logger=logger, max_epochs=config.max_epochs, min_epochs=config.min_epochs, precision=config.precision, callbacks=[ EarlyStopping( monitor="validation_loss", min_delta=1e-3, patience=3, mode="min" ), checkpoint2, ], profiler="simple", log_every_n_steps=config.log_every_n_steps, ) trainer2.fit(model2, data_module2) # Samp data_module3 = ContrastiveLearningDataModule_samp( train_df, val_df, query_df, tokenizer, dictionary, sim_df, ) data_module3.setup() print("Number of training data:", len(data_module3.train_dataset)) print("Number of validation data:", len(data_module3.val_dataset)) model3 = BERTContrastiveLearning_samp( n_batches=len(data_module3.train_dataset) / config.batch_size, n_epochs=config.max_epochs, lr=config.learning_rate, unfreeze=config.unfreeze_ratio, ) checkpoint3 = ModelCheckpoint( dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/samp/v1", filename="{epoch}-{step}", # save_weights_only=True, save_last=True, every_n_train_steps=config.log_every_n_steps, monitor=None, save_top_k=-1, ) trainer3 = pl.Trainer( accelerator=config.accelerator, devices=config.devices, strategy="ddp", logger=logger, max_epochs=config.max_epochs, min_epochs=config.min_epochs, precision=config.precision, callbacks=[ EarlyStopping( monitor="validation_loss", min_delta=1e-3, patience=3, mode="min" ), checkpoint3, ], profiler="simple", log_every_n_steps=config.log_every_n_steps, ) trainer3.fit(model3, data_module3) # Samp_w data_module4 = ContrastiveLearningDataModule_samp_w( train_df, val_df, query_df, tokenizer, dictionary, sim_df, all_d, ) data_module4.setup() print("Number of training data:", len(data_module4.train_dataset)) print("Number of validation data:", len(data_module4.val_dataset)) model4 = BERTContrastiveLearning_samp_w( n_batches=len(data_module4.train_dataset) / config.batch_size, n_epochs=config.max_epochs, lr=config.learning_rate, unfreeze=config.unfreeze_ratio, ) checkpoint4 = ModelCheckpoint( dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/samp_w/v1", filename="{epoch}-{step}", # save_weights_only=True, save_last=True, every_n_train_steps=config.log_every_n_steps, monitor=None, save_top_k=-1, ) trainer4 = pl.Trainer( accelerator=config.accelerator, devices=config.devices, strategy="ddp", logger=logger, max_epochs=config.max_epochs, min_epochs=config.min_epochs, precision=config.precision, callbacks=[ EarlyStopping( monitor="validation_loss", min_delta=1e-3, patience=3, mode="min" ), checkpoint4, ], profiler="simple", log_every_n_steps=config.log_every_n_steps, ) trainer4.fit(model4, data_module4)