from email.mime import audio from pathlib import Path from datetime import datetime import torch import torch.nn as nn from pytorch_lightning.trainer import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint import numpy as np import torchio as tio from medical_diffusion.data.datamodules import SimpleDataModule from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset from medical_diffusion.models.pipelines import DiffusionPipeline from medical_diffusion.models.estimators import UNet from medical_diffusion.external.stable_diffusion.unet_openai import UNetModel from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler from medical_diffusion.models.embedders import LabelEmbedder, TimeEmbbeding from medical_diffusion.models.embedders.latent_embedders import VAE, VAEGAN, VQVAE, VQGAN import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') if __name__ == "__main__": # ------------ Load Data ---------------- # ds = AIROGSDataset( # crawler_ext='jpg', # augment_horizontal_flip = False, # augment_vertical_flip = False, # # path_root='/home/gustav/Documents/datasets/AIROGS/data_256x256/', # path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256', # ) # ds = MSIvsMSS_2_Dataset( # crawler_ext='jpg', # image_resize=None, # image_crop=None, # augment_horizontal_flip=False, # augment_vertical_flip=False, # # path_root='/home/gustav/Documents/datasets/Kather_2/train', # path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', # ) ds = CheXpert_2_Dataset( # 256x256 augment_horizontal_flip=False, augment_vertical_flip=False, path_root = '/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu' ) dm = SimpleDataModule( ds_train = ds, batch_size=32, # num_workers=0, pin_memory=True, # weights=ds.get_weights() ) current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") path_run_dir = Path.cwd() / 'runs' / str(current_time) path_run_dir.mkdir(parents=True, exist_ok=True) accelerator = 'gpu' if torch.cuda.is_available() else 'cpu' # ------------ Initialize Model ------------ # cond_embedder = None cond_embedder = LabelEmbedder cond_embedder_kwargs = { 'emb_dim': 1024, 'num_classes': 2 } time_embedder = TimeEmbbeding time_embedder_kwargs ={ 'emb_dim': 1024 # stable diffusion uses 4*model_channels (model_channels is about 256) } noise_estimator = UNet noise_estimator_kwargs = { 'in_ch':8, 'out_ch':8, 'spatial_dims':2, 'hid_chs': [ 256, 256, 512, 1024], 'kernel_sizes':[3, 3, 3, 3], 'strides': [1, 2, 2, 2], 'time_embedder':time_embedder, 'time_embedder_kwargs': time_embedder_kwargs, 'cond_embedder':cond_embedder, 'cond_embedder_kwargs': cond_embedder_kwargs, 'deep_supervision': False, 'use_res_block':True, 'use_attention':'none', } # ------------ Initialize Noise ------------ noise_scheduler = GaussianNoiseScheduler noise_scheduler_kwargs = { 'timesteps': 1000, 'beta_start': 0.002, # 0.0001, 0.0015 'beta_end': 0.02, # 0.01, 0.0195 'schedule_strategy': 'scaled_linear' } # ------------ Initialize Latent Space ------------ # latent_embedder = None # latent_embedder = VQVAE latent_embedder = VAE latent_embedder_checkpoint = 'runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt' # ------------ Initialize Pipeline ------------ pipeline = DiffusionPipeline( noise_estimator=noise_estimator, noise_estimator_kwargs=noise_estimator_kwargs, noise_scheduler=noise_scheduler, noise_scheduler_kwargs = noise_scheduler_kwargs, latent_embedder=latent_embedder, latent_embedder_checkpoint = latent_embedder_checkpoint, estimator_objective='x_T', estimate_variance=False, use_self_conditioning=False, use_ema=False, classifier_free_guidance_dropout=0.5, # Disable during training by setting to 0 do_input_centering=False, clip_x0=False, sample_every_n_steps=1000 ) # pipeline_old = pipeline.load_from_checkpoint('runs/2022_11_27_085654_chest_diffusion/last.ckpt') # pipeline.noise_estimator.load_state_dict(pipeline_old.noise_estimator.state_dict(), strict=True) # -------------- Training Initialization --------------- to_monitor = "train/loss" # "pl/val_loss" min_max = "min" save_and_sample_every = 100 early_stopping = EarlyStopping( monitor=to_monitor, min_delta=0.0, # minimum change in the monitored quantity to qualify as an improvement patience=30, # number of checks with no improvement mode=min_max ) checkpointing = ModelCheckpoint( dirpath=str(path_run_dir), # dirpath monitor=to_monitor, every_n_train_steps=save_and_sample_every, save_last=True, save_top_k=2, mode=min_max, ) trainer = Trainer( accelerator=accelerator, # devices=[0], # precision=16, # amp_backend='apex', # amp_level='O2', # gradient_clip_val=0.5, default_root_dir=str(path_run_dir), callbacks=[checkpointing], # callbacks=[checkpointing, early_stopping], enable_checkpointing=True, check_val_every_n_epoch=1, log_every_n_steps=save_and_sample_every, auto_lr_find=False, # limit_train_batches=1000, limit_val_batches=0, # 0 = disable validation - Note: Early Stopping no longer available min_epochs=100, max_epochs=1001, num_sanity_val_steps=2, ) # ---------------- Execute Training ---------------- trainer.fit(pipeline, datamodule=dm) # ------------- Save path to best model ------------- pipeline.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)