# CHOPT / train.py
# sxtforreal — commit 3957f36: "Create train.py" (verified)
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import lightning.pytorch as pl
from pytorch_lightning.loggers import TensorBoardLogger
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from ast import literal_eval
# imports from our own modules
import config
from model import (
BERTContrastiveLearning_simcse,
BERTContrastiveLearning_simcse_w,
BERTContrastiveLearning_samp,
BERTContrastiveLearning_samp_w,
)
from dataset import (
ContrastiveLearningDataModule_simcse,
ContrastiveLearningDataModule_simcse_w,
ContrastiveLearningDataModule_samp,
ContrastiveLearningDataModule_samp_w,
)
# Root of the project tree that holds every input CSV and checkpoint directory.
PROJECT_ROOT = "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag"
QUERY_CSV = f"{PROJECT_ROOT}/data_proc/mimic_data/processed_train/processed.csv"
PAIRWISE_SCORES_CSV = f"{PROJECT_ROOT}/data_proc/pairwise_scores.csv"
ALL_D_CSV = f"{PROJECT_ROOT}/data_proc/all_d_full.csv"
CKPT_ROOT = f"{PROJECT_ROOT}/train/ckpt"
CKPT_VERSION = "v1"
TOKENIZER_NAME = "emilyalsentzer/Bio_ClinicalBERT"
SEED = 0


def _load_query_df(path=QUERY_CSV):
    """Load the query table and parse its list-valued columns.

    ``concepts`` and ``codes`` are stored as Python-literal strings and are
    parsed with ``literal_eval``. ``None`` entries are removed from every
    ``codes`` list, and the ``one_hot`` column is dropped.

    Args:
        path: CSV path or file-like object accepted by ``pd.read_csv``.

    Returns:
        The cleaned ``DataFrame``.
    """
    query_df = pd.read_csv(path)
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    # literal_eval keeps Python ``None`` placeholders; strip them from each list.
    query_df["codes"] = query_df["codes"].apply(
        lambda codes: [code for code in codes if code is not None]
    )
    return query_df.drop(columns=["one_hot"])


def _load_concept_table(path=ALL_D_CSV):
    """Load the concept table and build the concept -> synonyms mapping.

    Args:
        path: CSV path or file-like object accepted by ``pd.read_csv``.

    Returns:
        ``(all_d, dictionary)``. ``all_d`` is the table with its ``synonyms``
        and ``ancestors`` columns parsed via ``literal_eval``. ``dictionary``
        maps each ``concept`` value to its parsed ``synonyms``.
    """
    all_d = pd.read_csv(path)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
    return all_d, dictionary


def _checkpoint_dir(tag):
    """Return the checkpoint directory for the run named *tag*."""
    return f"{CKPT_ROOT}/{tag}/{CKPT_VERSION}"


def _fit(tag, data_module, model_cls):
    """Set up *data_module*, build a *model_cls* model and train it.

    All four variants share the same hyper-parameters, checkpointing and
    early-stopping settings from ``config``. Only the data module, the model
    class and the output *tag* differ between them.

    Args:
        tag: Run name, used for the checkpoint directory and log sub-directory.
        data_module: An un-setup data module exposing ``train_dataset`` and
            ``val_dataset`` after ``setup()``.
        model_cls: Model class taking ``n_batches``, ``n_epochs``, ``lr`` and
            ``unfreeze`` keyword arguments.

    Returns:
        The fitted ``pl.Trainer``.
    """
    # Use the logger from the same ``lightning.pytorch`` namespace as Trainer;
    # Trainer's isinstance checks (e.g. ``trainer.log_dir``) do not recognise
    # the ``pytorch_lightning`` class imported at file level.
    from lightning.pytorch.loggers import TensorBoardLogger

    data_module.setup()
    print("Number of training data:", len(data_module.train_dataset))
    print("Number of validation data:", len(data_module.val_dataset))

    model = model_cls(
        # NOTE(review): true division yields a float, and this ignores both a
        # trailing partial batch and the DDP split across devices. Confirm how
        # the model uses n_batches before changing it.
        n_batches=len(data_module.train_dataset) / config.batch_size,
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        unfreeze=config.unfreeze_ratio,
    )
    # save_weights_only stays at its default (False); set it to True to store
    # weights only.
    checkpoint = ModelCheckpoint(
        dirpath=_checkpoint_dir(tag),
        filename="{epoch}-{step}",
        save_last=True,
        every_n_train_steps=config.log_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )
    # A fresh logger per run: sharing one instance made all four runs write
    # into the same TensorBoard version directory (and hparams.yaml).
    logger = TensorBoardLogger("logs", name="MIMIC-tr", sub_dir=tag)
    trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            EarlyStopping(
                monitor="validation_loss", min_delta=1e-3, patience=3, mode="min"
            ),
            checkpoint,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )
    trainer.fit(model, data_module)
    return trainer


def main():
    """Train the SimCSE, SimCSE_w, Samp and Samp_w variants in sequence."""
    seed_everything(SEED, workers=True)

    query_df = _load_query_df()
    # For a quick smoke test, subsample here, e.g. query_df.head(1000).
    # An explicit random_state keeps the split identical on every DDP rank
    # without relying on global numpy seeding.
    train_df, val_df = train_test_split(
        query_df, test_size=config.split_ratio, random_state=SEED
    )
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    sim_df = pd.read_csv(PAIRWISE_SCORES_CSV)
    all_d, dictionary = _load_concept_table()

    # SimCSE
    _fit(
        "simcse",
        ContrastiveLearningDataModule_simcse(train_df, val_df, tokenizer),
        BERTContrastiveLearning_simcse,
    )
    # SimCSE_w
    _fit(
        "simcse_w",
        ContrastiveLearningDataModule_simcse_w(
            train_df, val_df, query_df, tokenizer, sim_df, all_d
        ),
        BERTContrastiveLearning_simcse_w,
    )
    # Samp
    _fit(
        "samp",
        ContrastiveLearningDataModule_samp(
            train_df, val_df, query_df, tokenizer, dictionary, sim_df
        ),
        BERTContrastiveLearning_samp,
    )
    # Samp_w
    _fit(
        "samp_w",
        ContrastiveLearningDataModule_samp_w(
            train_df, val_df, query_df, tokenizer, dictionary, sim_df, all_d
        ),
        BERTContrastiveLearning_samp_w,
    )


if __name__ == "__main__":
    main()