from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import ConcatDataset
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from medical_diffusion.data.datamodules import SimpleDataModule
from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')


if __name__ == "__main__":

    # --------------- Settings --------------------
    current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
    path_run_dir = Path.cwd() / 'runs' / str(current_time)
    path_run_dir.mkdir(parents=True, exist_ok=True)
    gpus = [0] if torch.cuda.is_available() else None

    # ------------ Load Data ----------------
    # ds_1 = AIROGSDataset(  # 256x256
    #     crawler_ext='jpg',
    #     augment_horizontal_flip=True,
    #     augment_vertical_flip=True,
    #     # path_root='/home/gustav/Documents/datasets/AIROGS/dataset',
    #     path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256',
    # )

    # ds_2 = MSIvsMSS_2_Dataset(  # 512x512
    #     # image_resize=256,
    #     crawler_ext='jpg',
    #     augment_horizontal_flip=True,
    #     augment_vertical_flip=True,
    #     # path_root='/home/gustav/Documents/datasets/Kather_2/train'
    #     path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/'
    # )

    ds_3 = CheXpert_2_Dataset(  # 256x256
        # image_resize=128,
        augment_horizontal_flip=False,
        augment_vertical_flip=False,
        # path_root='/home/gustav/Documents/datasets/CheXpert/preprocessed_tianyu'
        path_root='/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu'
    )

    # ds = ConcatDataset([ds_1, ds_2, ds_3])

    dm = SimpleDataModule(
        ds_train=ds_3,
        batch_size=8,
        # num_workers=0,
        pin_memory=True
    )

    # ------------ Initialize Model ------------
    model = VAE(
        in_channels=3,
        out_channels=3,
        emb_channels=8,
        spatial_dims=2,
        hid_chs=[64, 128, 256, 512],
        kernel_sizes=[3, 3, 3, 3],
        strides=[1, 2, 2, 2],
        deep_supervision=1,
        use_attention='none',
        loss=torch.nn.MSELoss,
        # optimizer_kwargs={'lr': 1e-6},
        embedding_loss_weight=1e-6
    )

    # model.load_pretrained(Path.cwd()/'runs/2022_12_01_183752_patho_vae/last.ckpt', strict=True)

    # model = VAEGAN(
    #     in_channels=3,
    #     out_channels=3,
    #     emb_channels=8,
    #     spatial_dims=2,
    #     hid_chs=[64, 128, 256, 512],
    #     deep_supervision=1,
    #     use_attention='none',
    #     start_gan_train_step=-1,
    #     embedding_loss_weight=1e-6
    # )

    # model.vqvae.load_pretrained(Path.cwd()/'runs/2022_11_25_082209_chest_vae/last.ckpt')
    # model.load_pretrained(Path.cwd()/'runs/2022_11_25_232957_patho_vaegan/last.ckpt')

    # model = VQVAE(
    #     in_channels=3,
    #     out_channels=3,
    #     emb_channels=4,
    #     num_embeddings=8192,
    #     spatial_dims=2,
    #     hid_chs=[64, 128, 256, 512],
    #     embedding_loss_weight=1,
    #     beta=1,
    #     loss=torch.nn.L1Loss,
    #     deep_supervision=1,
    #     use_attention='none',
    # )

    # model = VQGAN(
    #     in_channels=3,
    #     out_channels=3,
    #     emb_channels=4,
    #     num_embeddings=8192,
    #     spatial_dims=2,
    #     hid_chs=[64, 128, 256, 512],
    #     embedding_loss_weight=1,
    #     beta=1,
    #     start_gan_train_step=-1,
    #     pixel_loss=torch.nn.L1Loss,
    #     deep_supervision=1,
    #     use_attention='none',
    # )

    # model.vqvae.load_pretrained(Path.cwd()/'runs/2022_12_13_093727_patho_vqvae/last.ckpt')

    # -------------- Training Initialization ---------------
    to_monitor = "train/L1"  # "val/loss"
    min_max = "min"
    save_and_sample_every = 50

    early_stopping = EarlyStopping(
        monitor=to_monitor,
        min_delta=0.0,  # minimum change in the monitored quantity to qualify as an improvement
        patience=30,  # number of checks with no improvement
        mode=min_max
    )
    checkpointing = ModelCheckpoint(
        dirpath=str(path_run_dir),
        monitor=to_monitor,
        every_n_train_steps=save_and_sample_every,
        save_last=True,
        save_top_k=5,
        mode=min_max,
    )
    trainer = Trainer(
        accelerator='gpu' if gpus else 'cpu',  # fall back to CPU when no GPU is available
        devices=gpus if gpus else 1,
        # precision=16,
        # amp_backend='apex',
        # amp_level='O2',
        # gradient_clip_val=0.5,
        default_root_dir=str(path_run_dir),
        callbacks=[checkpointing],
        # callbacks=[checkpointing, early_stopping],
        enable_checkpointing=True,
        check_val_every_n_epoch=1,
        log_every_n_steps=save_and_sample_every,
        auto_lr_find=False,
        # limit_train_batches=1000,
        limit_val_batches=0,  # 0 = disable validation; note that early stopping on a val metric is then unavailable
        min_epochs=100,
        max_epochs=1001,
        num_sanity_val_steps=2,
    )

    # ---------------- Execute Training ----------------
    trainer.fit(model, datamodule=dm)

    # ------------- Save path to best model -------------
    model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)
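
    # ------------- Optional: reload the best checkpoint (sketch) -------------
    # A minimal, hypothetical sketch, not part of the original training flow:
    # since VAE is trained via pytorch_lightning's Trainer above, the standard
    # LightningModule API should allow reloading the best checkpoint for
    # downstream use, e.g. as the frozen latent embedder of a diffusion
    # pipeline. This assumes VAE stores its hyperparameters in the checkpoint
    # (standard behavior when save_hyperparameters() is called in __init__);
    # if not, pass the constructor arguments explicitly.
    # best_model = VAE.load_from_checkpoint(checkpointing.best_model_path)
    # best_model.freeze()  # eval mode + requires_grad=False, ready for encoding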